{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Toy Dataset: Sorting\n",
    "#### Prime numbers as simulated emotion words and emotion categories:\n",
    "- ECM should use external memory when predicting primes\n",
    "- Primes are equally split into (num_emo) \"emotion\" categories\n",
    "- Each sequence is tagged with the \"emotion\" category that contributes the most primes to it (ties, including sequences without any primes, go to the lowest category index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check if a number is a prime\n",
    "def is_prime(n):\n",
    "    \"\"\"\n",
    "    Return True if the non-negative integer n is prime.\n",
    "        0 and 1 are not prime and 2 is the only even prime; odd n is\n",
    "        checked by trial division with odd divisors up to sqrt(n).\n",
    "    \"\"\"\n",
    "    if n < 2:\n",
    "        return False\n",
    "    if n % 2 == 0:\n",
    "        return n == 2\n",
    "    for i in range(3, int(np.sqrt(n)) + 1, 2):\n",
    "        if n % i == 0:\n",
    "            return False\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "N = 1000     # sequences are drawn from the integers 0 .. N-1\n",
    "num_emo = 4  # number of simulated \"emotion\" categories\n",
    "SEED = 42    # fixes np.random so the generated train/dev data is reproducible\n",
    "\n",
    "np.random.seed(SEED)\n",
    "\n",
    "nums = np.arange(N)\n",
    "check_prime = np.vectorize(is_prime)\n",
    "primes = nums[check_prime(nums)]\n",
    "\n",
    "# equally split primes into categories (contiguous ascending ranges)\n",
    "s_primes = [s_p.tolist() for s_p in np.array_split(primes, num_emo)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train\n",
    "source_data = []\n",
    "target_data = []\n",
    "choice_data = []\n",
    "category_data = []\n",
    "num_data = 20000\n",
    "\n",
    "for _ in range(num_data):\n",
    "    # source: 15-25 random integers from [0, N); target: the same, sorted\n",
    "    length = 15 + np.random.choice(11)\n",
    "    s = np.random.choice(N, length)\n",
    "    t = np.sort(s)\n",
    "    # 1: emotion words/primes, 0: generic words\n",
    "    # (builtin int: the np.int alias was removed in NumPy 1.24)\n",
    "    q = check_prime(t).astype(int)\n",
    "\n",
    "    # number of target tokens in each prime category; argmax keeps the\n",
    "    # first category on ties (category 0 when t contains no primes)\n",
    "    counts = [np.isin(t, s_p).sum() for s_p in s_primes]\n",
    "    category = np.argmax(counts)\n",
    "\n",
    "    source_data.append(\" \".join(s.astype(str).tolist()))\n",
    "    target_data.append(\" \".join(t.astype(str).tolist()))\n",
    "    choice_data.append(\" \".join(q.astype(str).tolist()))\n",
    "    category_data.append(str(category))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(data={\"0\": source_data})\n",
    "tdf = pd.DataFrame(data={\"0\": target_data})\n",
    "qdf = pd.DataFrame(data={\"0\": choice_data})\n",
    "cdf = pd.DataFrame(data={\"0\": category_data})\n",
    "\n",
    "# exist_ok makes the cell safe to re-run (no check-then-create race)\n",
    "os.makedirs(\"./example/\", exist_ok=True)\n",
    "\n",
    "# one example per line: no header row, no index column\n",
    "for frame, field in [(df, \"source\"), (tdf, \"target\"),\n",
    "                     (qdf, \"choice\"), (cdf, \"category\")]:\n",
    "    frame.to_csv(\"./example/train_%s.txt\" % field, header=False, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dev\n",
    "dev_source_data = []\n",
    "dev_target_data = []\n",
    "dev_choice_data = []\n",
    "dev_category_data = []\n",
    "dev_num_data = 1000\n",
    "\n",
    "for _ in range(dev_num_data):\n",
    "    # source: 15-25 random integers from [0, N); target: the same, sorted\n",
    "    length = 15 + np.random.choice(11)\n",
    "    s = np.random.choice(N, length)\n",
    "    t = np.sort(s)\n",
    "    # 1: emotion words/primes, 0: generic words\n",
    "    # (builtin int: the np.int alias was removed in NumPy 1.24)\n",
    "    q = check_prime(t).astype(int)\n",
    "\n",
    "    # number of target tokens in each prime category; argmax keeps the\n",
    "    # first category on ties (category 0 when t contains no primes)\n",
    "    counts = [np.isin(t, s_p).sum() for s_p in s_primes]\n",
    "    category = np.argmax(counts)\n",
    "\n",
    "    dev_source_data.append(\" \".join(s.astype(str).tolist()))\n",
    "    dev_target_data.append(\" \".join(t.astype(str).tolist()))\n",
    "    dev_choice_data.append(\" \".join(q.astype(str).tolist()))\n",
    "    dev_category_data.append(str(category))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "ddf = pd.DataFrame(data={\"0\": dev_source_data})\n",
    "tddf = pd.DataFrame(data={\"0\": dev_target_data})\n",
    "qddf = pd.DataFrame(data={\"0\": dev_choice_data})\n",
    "cddf = pd.DataFrame(data={\"0\": dev_category_data})\n",
    "\n",
    "# create the output directory here as well, so this cell does not rely on\n",
    "# the train-export cell having run first\n",
    "os.makedirs(\"./example/\", exist_ok=True)\n",
    "\n",
    "# one example per line: no header row, no index column\n",
    "for frame, field in [(ddf, \"source\"), (tddf, \"target\"),\n",
    "                     (qddf, \"choice\"), (cddf, \"category\")]:\n",
    "    frame.to_csv(\"./example/dev_%s.txt\" % field, header=False, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encoder and ECM Wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_embeddings(vocab_size, embed_size, dtype=tf.float32,\n",
    "                    initializer=None, initial_values=None,\n",
    "                    name='embeddings'):\n",
    "    \"\"\"\n",
    "    embeddings:\n",
    "        initialize trainable embeddings or load pretrained from files\n",
    "        vocab_size: number of regular tokens (unused when initial_values is given)\n",
    "        embed_size: embedding dimension\n",
    "        initializer: callable(shape=...) for the random init; defaults to Xavier\n",
    "        initial_values: optional pretrained matrix with embed_size columns\n",
    "    Returns:\n",
    "        embeddings: ids 0/1/2 are SOS/EOS/all-zeros, followed by the regular\n",
    "            token rows (so regular token i lives at row i + 3)\n",
    "    \"\"\"\n",
    "    with tf.variable_scope(name):\n",
    "        # `is not None`: a numpy array has no unambiguous truth value\n",
    "        if initial_values is not None:\n",
    "            embeddings = tf.Variable(initial_value=initial_values,\n",
    "                                     name=\"embeddings\", dtype=dtype)\n",
    "        else:\n",
    "            if initializer is None:\n",
    "                initializer = tf.contrib.layers.xavier_initializer()\n",
    "\n",
    "            embeddings = tf.Variable(\n",
    "                initializer(shape=(vocab_size, embed_size)),\n",
    "                name=\"embeddings\", dtype=dtype)\n",
    "\n",
    "        # id_0 represents SOS token, id_1 represents EOS token\n",
    "        se_embed = tf.get_variable(\"SOS/EOS\", [2, embed_size], dtype)\n",
    "        # id_2 represents constant all zeros (same dtype, or concat fails)\n",
    "        zero_embed = tf.zeros(shape=[1, embed_size], dtype=dtype)\n",
    "        embeddings = tf.concat([se_embed, zero_embed, embeddings], axis=0)\n",
    "\n",
    "    return embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder\n",
    "from tensorflow.contrib.rnn import LSTMStateTuple\n",
    "\n",
    "def create_cell(num_units, cell_type, forget_bias=1.0):\n",
    "    \"\"\"\n",
    "    Build a single recurrent cell.\n",
    "        num_units: hidden size of the cell\n",
    "        cell_type: \"LSTM\", \"GRU\" or \"LN_LSTM\" (layer-normalized LSTM)\n",
    "        forget_bias: forget-gate bias, used by the LSTM variants only\n",
    "    Raises:\n",
    "        ValueError: for an unrecognised cell_type\n",
    "    \"\"\"\n",
    "    # lazily constructed, so only the requested cell is ever built\n",
    "    builders = {\n",
    "        \"LSTM\": lambda: tf.nn.rnn_cell.BasicLSTMCell(\n",
    "            num_units, forget_bias=forget_bias),\n",
    "        \"GRU\": lambda: tf.nn.rnn_cell.GRUCell(num_units),\n",
    "        \"LN_LSTM\": lambda: tf.contrib.rnn.LayerNormBasicLSTMCell(\n",
    "            num_units, forget_bias=forget_bias, layer_norm=True),\n",
    "    }\n",
    "    if cell_type not in builders:\n",
    "        raise ValueError(\"Unknown cell type %s\" % cell_type)\n",
    "\n",
    "    return builders[cell_type]()\n",
    "\n",
    "\n",
    "def build_rnn_cell(num_layers, num_units, cell_type, forget_bias=1.0):\n",
    "    \"\"\"\n",
    "    Build num_layers cells of the same type and stack them.\n",
    "        num_layers: number of hidden layers; a single layer is returned as\n",
    "            the bare cell instead of being wrapped in MultiRNNCell\n",
    "    \"\"\"\n",
    "    layers = [create_cell(num_units, cell_type, forget_bias)\n",
    "              for _ in range(num_layers)]\n",
    "\n",
    "    if num_layers <= 1:\n",
    "        return layers[0]\n",
    "    return tf.nn.rnn_cell.MultiRNNCell(layers)\n",
    "\n",
    "\n",
    "def build_encoder(embeddings, source_ids, num_layers, num_units, cell_type,\n",
    "                  forget_bias=1.0, bidir=False, time_major=False,\n",
    "                  dtype=tf.float32, name=\"encoder\"):\n",
    "    \"\"\"\n",
    "    encoder: build rnn encoder for Seq2seq\n",
    "        source_ids: [batch_size, max_time]\n",
    "        bidir: bidirectional or unidirectional; each bidirectional layer sums\n",
    "            the fw/bw final states and projects the concatenated outputs\n",
    "            back to num_units\n",
    "        time_major: if True, source_ids are transposed and the rnn runs\n",
    "            time-major, so outputs become [max_time, batch_size, num_units]\n",
    "\n",
    "    Returns:\n",
    "        encoder_outputs: [batch_size, max_time, num_units]\n",
    "        encoder_states: (StateTuple(shape=(batch_size, num_units)), ...)\n",
    "    \"\"\"\n",
    "    with tf.variable_scope(name):\n",
    "        if time_major:\n",
    "            source_ids = tf.transpose(source_ids)\n",
    "\n",
    "        # embed_inputs: [batch_size, max_time, embed_size]\n",
    "        # (or [max_time, batch_size, embed_size] when time_major)\n",
    "        embed_inputs = tf.nn.embedding_lookup(embeddings, source_ids)\n",
    "\n",
    "        # bidirectional\n",
    "        if bidir:\n",
    "            encoder_states = []\n",
    "            layer_inputs = embed_inputs\n",
    "\n",
    "            # build rnn layer-by-layer\n",
    "            for i in range(num_layers):\n",
    "                with tf.variable_scope(\"layer_%d\" % (i + 1)):\n",
    "                    fw_cell = build_rnn_cell(\n",
    "                        1, num_units, cell_type, forget_bias)\n",
    "                    bw_cell = build_rnn_cell(\n",
    "                        1, num_units, cell_type, forget_bias)\n",
    "\n",
    "                    dyn_rnn = tf.nn.bidirectional_dynamic_rnn(\n",
    "                        fw_cell, bw_cell, layer_inputs,\n",
    "                        time_major=time_major,\n",
    "                        dtype=dtype,\n",
    "                        swap_memory=True)\n",
    "                    bi_outputs, (state_fw, state_bw) = dyn_rnn\n",
    "\n",
    "                    # handle cell memory state: LSTM and LN_LSTM both return\n",
    "                    # LSTMStateTuple, which must be summed field by field\n",
    "                    # (`+` on two tuples would concatenate them instead)\n",
    "                    if isinstance(state_fw, LSTMStateTuple):\n",
    "                        state_c = state_fw.c + state_bw.c\n",
    "                        state_h = state_fw.h + state_bw.h\n",
    "                        encoder_states.append(LSTMStateTuple(state_c, state_h))\n",
    "                    else:\n",
    "                        encoder_states.append(state_fw + state_bw)\n",
    "\n",
    "                    # concat and map as inputs of next layer\n",
    "                    layer_inputs = tf.layers.dense(\n",
    "                        tf.concat(bi_outputs, -1), num_units)\n",
    "\n",
    "            encoder_outputs = layer_inputs\n",
    "            encoder_states = tuple(encoder_states)\n",
    "\n",
    "        # unidirectional\n",
    "        else:\n",
    "            rnn_cell = build_rnn_cell(\n",
    "                num_layers, num_units, cell_type, forget_bias)\n",
    "\n",
    "            encoder_outputs, encoder_states = tf.nn.dynamic_rnn(\n",
    "                rnn_cell, embed_inputs,\n",
    "                time_major=time_major,\n",
    "                dtype=dtype,\n",
    "                swap_memory=True)\n",
    "\n",
    "    return encoder_outputs, encoder_states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vocabulary = the N toy integers; init_embeddings prepends SOS/EOS/zero\n",
    "# rows, so the resulting matrix has N + 3 rows\n",
    "embeddings = init_embeddings(vocab_size=N, embed_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# token-id input, [batch_size, max_time]; both sizes are left dynamic\n",
    "source_ids = tf.placeholder(tf.int32, [None, None])\n",
    "# NOTE(review): the scope name \"e8\" looks like a leftover from re-running this\n",
    "# cell in one graph; a fresh kernel could use the default \"encoder\" scope -\n",
    "# confirm nothing (e.g. checkpoints) depends on this name before changing it\n",
    "encoder_outputs, encoder_states = build_encoder(embeddings, source_ids, num_layers=2,\n",
    "                                                num_units=256, cell_type=\"LSTM\", bidir=True,\n",
    "                                                name=\"e8\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### ECM Wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ECM wrapper\n",
    "from tensorflow.contrib.rnn import RNNCell\n",
    "\n",
    "import collections\n",
    "\n",
    "\n",
    "# loop state of ECMWrapper: wrapped-cell state, last output, attention\n",
    "# context and the internal emotion memory\n",
    "ECMState = collections.namedtuple(\n",
    "    \"ECMState\", [\"cell_states\", \"h\", \"context\", \"internal_memory\"])\n",
    "\n",
    "\n",
    "class ECMWrapper(RNNCell):\n",
    "    \"\"\"\n",
    "    Emotion Chatting Machine: H. Zhou, et al. AAAI 2018\n",
    "    (https://arxiv.org/abs/1704.01074)\n",
    "    Emotion Category Embedding, Internal and External Memory Modules\n",
    "        cell: vanilla multi-layer RNNCell\n",
    "        memory: [batch_size, max_time, num_units]\n",
    "        dec_init_states: None (use cell.zero_state), or initial cell states\n",
    "        num_hidden: hidden size of the additive attention over memory\n",
    "        num_units: decoder's num of cell units\n",
    "        dtype: dtype of the states and of the internal memory\n",
    "        emo_cat_embs: category embeddings, [batch_size, emo_cat_units]\n",
    "        emo_cat: emotion category, [batch_size]\n",
    "        num_emo: number of emotion categories (rows of the internal memory)\n",
    "        emo_int_units: dimension of internal emotion memory\n",
    "        emo_init: optional initializer for the internal memory, called as\n",
    "            emo_init(shape=(num_emo, emo_int_units)); defaults to Xavier\n",
    "    \"\"\"\n",
    "    def __init__(self, cell, memory, dec_init_states, num_hidden,\n",
    "                 num_units, dtype, emo_cat_embs, emo_cat, num_emo,\n",
    "                 emo_int_units, emo_init=None):\n",
    "        self._cell = cell\n",
    "        self._memory = memory\n",
    "        self.num_hidden = num_hidden\n",
    "\n",
    "        self._dec_init_states = dec_init_states\n",
    "        self._state_size = ECMState(self._cell.state_size,\n",
    "                                    num_units, memory.shape[-1].value,\n",
    "                                    emo_int_units)\n",
    "        self._num_units = num_units\n",
    "        self._dtype = dtype\n",
    "\n",
    "        # ECM hyperparameters\n",
    "        self._emo_cat_embs = emo_cat_embs\n",
    "        self._emo_cat = emo_cat\n",
    "        self._emo_int_units = emo_int_units\n",
    "\n",
    "        # internal memory: one trainable row per emotion category;\n",
    "        # a caller-supplied emo_init is used as-is, Xavier otherwise\n",
    "        initializer = emo_init\n",
    "        if initializer is None:\n",
    "            initializer = tf.contrib.layers.xavier_initializer()\n",
    "\n",
    "        self.int_memory = tf.Variable(\n",
    "            initializer(shape=(num_emo, emo_int_units)),\n",
    "            name=\"emo_memory\", dtype=dtype)\n",
    "\n",
    "        self.read_g = tf.layers.Dense(\n",
    "            emo_int_units, use_bias=False, name=\"internal_read_gate\")\n",
    "        self.write_g = tf.layers.Dense(\n",
    "            emo_int_units, use_bias=False, name=\"internal_write_gate\")\n",
    "\n",
    "        # attention projections are built once so that _compute_context\n",
    "        # reuses the same weights in initial_state and in every step\n",
    "        # (tf.layers.dense would create new variables on each call)\n",
    "        self._attn_query = tf.layers.Dense(\n",
    "            num_hidden, use_bias=False, name=\"attention_query\")\n",
    "        self._attn_memory = tf.layers.Dense(\n",
    "            num_hidden, use_bias=False, name=\"attention_memory\")\n",
    "        self._attn_score = tf.layers.Dense(\n",
    "            1, use_bias=False, name=\"attention_score\")\n",
    "\n",
    "    @property\n",
    "    def state_size(self):\n",
    "        return self._state_size\n",
    "\n",
    "    @property\n",
    "    def output_size(self):\n",
    "        return self._num_units\n",
    "\n",
    "    def initial_state(self):\n",
    "        \"\"\"\n",
    "        Generate initial state for ECM wrapped rnn cell\n",
    "            dec_init_states: None (no states pass), or encoder final states\n",
    "            num_units: decoder's num of cell units\n",
    "        Returns:\n",
    "            h_0: [batch_size, num_units]\n",
    "            context_0: [batch_size, num_units]\n",
    "            M_emo_0: [batch_size, emo_int_units]\n",
    "        \"\"\"\n",
    "        h_0 = tf.zeros([1, self._num_units], self._dtype)\n",
    "        context_0 = self._compute_context(h_0)\n",
    "        # zeros carrying memory's batch size; NOTE(review): this takes the\n",
    "        # context's depth, so it assumes memory depth == num_units\n",
    "        h_0 = context_0 * 0\n",
    "        M_emo_0 = tf.gather(self.int_memory, self._emo_cat)\n",
    "\n",
    "        if self._dec_init_states is None:\n",
    "            batch_size = tf.shape(self._memory)[0]\n",
    "            cell_states = self._cell.zero_state(batch_size, self._dtype)\n",
    "        else:\n",
    "            cell_states = self._dec_init_states\n",
    "\n",
    "        ecm_state_0 = ECMState(cell_states, h_0, context_0, M_emo_0)\n",
    "\n",
    "        return ecm_state_0\n",
    "\n",
    "    def _compute_context(self, query):\n",
    "        \"\"\"\n",
    "        Compute attn scores and weighted sum of memory as the context\n",
    "            query: [batch_size, num_units]\n",
    "        Returns:\n",
    "            context: [batch_size, num_units]\n",
    "        NOTE(review): padded memory positions are not masked before softmax.\n",
    "        \"\"\"\n",
    "        query = tf.expand_dims(query, -2)\n",
    "        Wq = self._attn_query(query)\n",
    "        Wm = self._attn_memory(self._memory)\n",
    "        e = self._attn_score(tf.nn.tanh(Wm + Wq))\n",
    "        attn_scores = tf.expand_dims(tf.nn.softmax(tf.squeeze(e, axis=-1)), -1)\n",
    "\n",
    "        context = tf.reduce_sum(attn_scores * self._memory, axis=1)\n",
    "\n",
    "        return context\n",
    "\n",
    "    def _read_internal_memory(self, M_emo, read_inputs):\n",
    "        \"\"\"\n",
    "        Read the internal memory\n",
    "            M_emo: [batch_size, emo_int_units]\n",
    "            read_inputs: [batch_size, d]\n",
    "        Returns:\n",
    "            M_read: [batch_size, emo_int_units]\n",
    "        \"\"\"\n",
    "        gate_read = tf.nn.sigmoid(self.read_g(read_inputs))\n",
    "        return (M_emo * gate_read)\n",
    "\n",
    "    def _write_internal_memory(self, M_emo, new_h):\n",
    "        \"\"\"\n",
    "        Write the internal memory\n",
    "            M_emo: [batch_size, emo_int_units]\n",
    "            new_h: [batch_size, num_units]\n",
    "        Returns:\n",
    "            M_write: [batch_size, emo_int_units]\n",
    "        \"\"\"\n",
    "        gate_write = tf.nn.sigmoid(self.write_g(new_h))\n",
    "        return (M_emo * gate_write)\n",
    "\n",
    "    def __call__(self, inputs, ecm_states):\n",
    "        \"\"\"\n",
    "            inputs: embeddings of previous word\n",
    "            states: (cell_states, outputs, context, int_memory)\n",
    "        \"\"\"\n",
    "        prev_cell_states, h, context, M_emo = ecm_states\n",
    "\n",
    "        # read internal memory\n",
    "        read_inputs = tf.concat([inputs, h, context], axis=-1)\n",
    "        M_read = self._read_internal_memory(M_emo, read_inputs)\n",
    "\n",
    "        # pass into RNN_cell to get the output\n",
    "        x = [inputs, h, context, self._emo_cat_embs, M_read]\n",
    "        x = tf.concat(x, axis=-1)\n",
    "        new_h, cell_states = self._cell.__call__(x, prev_cell_states)\n",
    "\n",
    "        # update states\n",
    "        new_M_emo = self._write_internal_memory(M_emo, new_h)\n",
    "        new_context = self._compute_context(new_h)\n",
    "        new_ecm_states = ECMState(cell_states, new_h, new_context, new_M_emo)\n",
    "\n",
    "        return (new_h, new_ecm_states)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Beam Search Decoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ### Beam Search helpers ###\n",
    "def tile_beam(tensor, beam_size):\n",
    "    \"\"\"\n",
    "    Repeat every batch entry beam_size times along a new axis 1.\n",
    "        tensor: batch-major, [batch_size, ...]\n",
    "    Returns:\n",
    "        tensor: [batch_size, beam_size, ...]\n",
    "    \"\"\"\n",
    "    expanded = tf.expand_dims(tensor, axis=1)\n",
    "    # tile only the new beam axis: [1, beam_size, 1, ..., 1]\n",
    "    multiples = [1] * expanded.shape.ndims\n",
    "    multiples[1] = beam_size\n",
    "    return tf.tile(expanded, multiples)\n",
    "\n",
    "\n",
    "def merge_batch_beam(tensor):\n",
    "    \"\"\"\n",
    "    Fold the beam axis into the batch axis.\n",
    "        tensor: [batch_size, beam_size, ...], beam_size statically known\n",
    "    Returns:\n",
    "        tensor: [batch_size * beam_size, ...]\n",
    "    \"\"\"\n",
    "    # batch size may be dynamic, so it comes from tf.shape; the beam size is\n",
    "    # static and can be read straight from the tensor's static shape\n",
    "    merged_dim = tf.shape(tensor)[0] * tensor.shape[1].value\n",
    "    trailing_dims = list(tensor.shape)[2:]\n",
    "\n",
    "    return tf.reshape(tensor, [merged_dim] + trailing_dims)\n",
    "\n",
    "\n",
    "def split_batch_beam(tensor, beam_size):\n",
    "    \"\"\"\n",
    "    Inverse of merge_batch_beam: unfold the leading axis into batch x beam.\n",
    "        tensor: [batch_size * beam_size, ...]\n",
    "    Returns:\n",
    "        tensor: [batch_size, beam_size, ...]\n",
    "    \"\"\"\n",
    "    # -1 lets reshape infer the (possibly dynamic) batch size\n",
    "    target_shape = [-1, beam_size] + list(tensor.shape)[1:]\n",
    "    return tf.reshape(tensor, target_shape)\n",
    "\n",
    "\n",
    "def mask_log_probs(log_probs, end_id, decode_finished):\n",
    "    \"\"\"\n",
    "    Set log_probs after end_token to be [-inf, 0, -inf, ...]\n",
    "        log_probs: [batch_size, beam_size, vocab_size]\n",
    "        decode_finished: [batch_size, beam_size]\n",
    "    Finished beams get 0 at end_id and dtype.min (a finite stand-in for\n",
    "    -inf) elsewhere, so they can only repeat end_id at no extra cost;\n",
    "    unfinished beams keep their log_probs unchanged.\n",
    "    \"\"\"\n",
    "    vocab_size = log_probs.shape[-1].value\n",
    "    one_hot = tf.one_hot(end_id, vocab_size, on_value=0.0,\n",
    "                         off_value=log_probs.dtype.min,\n",
    "                         dtype=log_probs.dtype)\n",
    "\n",
    "    # select rather than blend with 0/1 weights: 0 * -inf is NaN, which\n",
    "    # would poison finished beams whose raw log_probs contain -inf\n",
    "    finished = tf.tile(\n",
    "        tf.expand_dims(tf.cast(decode_finished, tf.bool), axis=-1),\n",
    "        [1, 1, vocab_size])\n",
    "    end_only = tf.zeros_like(log_probs) + one_hot\n",
    "\n",
    "    return tf.where(finished, end_only, log_probs)\n",
    "\n",
    "\n",
    "def sample_bernoulli(prob, shape):\n",
    "    \"\"\"Boolean tensor of the given shape; each entry is True with probability prob.\"\"\"\n",
    "    uniform_draws = tf.random_uniform(shape)\n",
    "    return tf.greater(prob, uniform_draws)\n",
    "\n",
    "\n",
    "def add_diversity_penalty(log_probs, div_gamma, div_prob, batch_size,\n",
    "                          beam_size, vocab_size):\n",
    "    \"\"\"\n",
    "    Diversity penalty by Li et al. 2016\n",
    "        div_gamma: (float) diversity parameter; every entry gets\n",
    "            log(div_gamma) * (its rank within its beam, 0 = most likely)\n",
    "        div_prob: probability of applying the penalty to each entry\n",
    "    Returns log_probs unchanged when either parameter is None, when\n",
    "    div_gamma == 1 or when div_prob == 0.\n",
    "    \"\"\"\n",
    "    disabled = (div_gamma is None or div_prob is None or\n",
    "                div_gamma == 1 or div_prob == 0)\n",
    "    if disabled:\n",
    "        return log_probs\n",
    "\n",
    "    # rank of each vocabulary entry inside its beam, obtained by inverting\n",
    "    # the permutation that sorts the beam in descending order\n",
    "    _, order = tf.nn.top_k(log_probs, k=vocab_size, sorted=True)\n",
    "    flat_order = tf.reshape(order, [-1, vocab_size])\n",
    "    ranks = tf.map_fn(tf.invert_permutation, flat_order, back_prop=False)\n",
    "    ranks = tf.reshape(ranks, shape=[batch_size, beam_size, vocab_size])\n",
    "\n",
    "    # per-entry penalty, kept independently with probability div_prob\n",
    "    penalties = tf.log(div_gamma) * tf.cast(ranks, log_probs.dtype)\n",
    "    keep = sample_bernoulli(div_prob, [batch_size, beam_size, vocab_size])\n",
    "    penalties = penalties * tf.cast(keep, penalties.dtype)\n",
    "\n",
    "    return log_probs + penalties\n",
    "\n",
    "\n",
    "def gather_helper(tensor, indices, batch_size, beam_size):\n",
    "    \"\"\"\n",
    "    Reorder beams: new_tensor[b, i] = tensor[b, indices[b, i]]\n",
    "        tensor: [batch_size, beam_size, ...]; static rank >= 2 and\n",
    "            statically known trailing dims\n",
    "        indices: [batch_size, beam_size], beam ids within each batch entry\n",
    "    Returns:\n",
    "        new_tensor: same static shape as tensor\n",
    "    Raises:\n",
    "        ValueError: if the rank of tensor is unknown or below 2\n",
    "    \"\"\"\n",
    "    ndims = tensor.shape.ndims\n",
    "    if ndims is None or ndims < 2:\n",
    "        raise ValueError(\"gather_helper expects a tensor of rank >= 2, \"\n",
    "                         \"got rank %s\" % ndims)\n",
    "\n",
    "    # offset each batch entry's beam ids into the flattened batch*beam axis\n",
    "    range_ = tf.expand_dims(tf.range(batch_size) * beam_size, axis=1)\n",
    "    flat_indices = tf.reshape(indices + range_, [-1])\n",
    "    output = tf.gather(tf.reshape(tensor, [batch_size * beam_size, -1]),\n",
    "                       flat_indices)\n",
    "\n",
    "    # restore [batch_size, beam_size] plus the original trailing dims\n",
    "    trailing_dims = tensor.shape[2:].as_list()\n",
    "    return tf.reshape(output, [batch_size, beam_size] + trailing_dims)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Beam search decoding\n",
    "from tensorflow.contrib.framework import nest\n",
    "\n",
    "\n",
    "class DecoderOutput(\n",
    "        collections.namedtuple(\"DecoderOutput\", [\"logits\", \"ids\"])):\n",
    "    \"\"\"\n",
    "    Decoder output without beam bookkeeping (used for finalized results).\n",
    "        logits: [batch_size, vocab_size]\n",
    "        ids: [batch_size]\n",
    "    \"\"\"\n",
    "\n",
    "\n",
    "class BeamDecoderOutput(\n",
    "        collections.namedtuple(\"BeamDecoderOutput\",\n",
    "                               [\"logits\", \"ids\", \"parents\"])):\n",
    "    \"\"\"\n",
    "    Per-step beam-search output.\n",
    "        logits: [batch_size, beam_size, vocab_size]\n",
    "        ids: [batch_size, beam_size], ids of the currently best words\n",
    "        parents: [batch_size, beam_size], beam index each id extends\n",
    "    \"\"\"\n",
    "\n",
    "\n",
    "class BeamDecoderCellStates(\n",
    "        collections.namedtuple(\"BeamDecoderCellStates\",\n",
    "                               [\"cell_states\", \"log_probs\"])):\n",
    "    \"\"\"\n",
    "    Loop state of beam search.\n",
    "        cell_states: [batch_size, beam_size, num_units]\n",
    "        log_probs: [batch_size, beam_size], cumulative per-beam log-probs\n",
    "    \"\"\"\n",
    "\n",
    "\n",
    "class ECMBeamSearchDecodeCell(object):\n",
    "\n",
    "    def __init__(self, embeddings, cell, dec_init_states, output_layer,\n",
    "                 emo_output_layer, emo_choice_layer, batch_size, dtype,\n",
    "                 beam_size, vocab_size, div_gamma=None, div_prob=None):\n",
    "        \"\"\"\n",
    "            embeddings: embedding matrix; row 0 is used as the start token\n",
    "            cell: decoder cell; its _memory / _emo_cat_embs / _emo_cat, when\n",
    "                present, are repeated beam_size times per batch entry\n",
    "            dec_init_states: initial cell states, batch-major [batch_size, ...]\n",
    "            batch_size: python int (used with np.repeat and static shapes)\n",
    "            vocab_size: size of one output vocabulary; the generic and emotion\n",
    "                vocabularies are concatenated, giving 2 * vocab_size ids\n",
    "            div_gamma: (float) relative weight of penalties\n",
    "            div_prob: (float) prob to apply penalties\n",
    "        \"\"\"\n",
    "        self._embeddings = embeddings\n",
    "        self._origin_vocab_size = vocab_size\n",
    "        self._vocab_size = vocab_size * 2\n",
    "        self._cell = cell\n",
    "        self._dec_init_states = dec_init_states\n",
    "\n",
    "        self._output_layer = output_layer\n",
    "        self._emo_output_layer = emo_output_layer\n",
    "        self._emo_choice_layer = emo_choice_layer\n",
    "\n",
    "        self._batch_size = batch_size\n",
    "        self._start_token = tf.nn.embedding_lookup(embeddings, 0)\n",
    "        self._end_id = 1\n",
    "        self._dtype = dtype\n",
    "\n",
    "        self._beam_size = beam_size\n",
    "        self._div_gamma = div_gamma\n",
    "        self._div_prob = div_prob\n",
    "\n",
    "        # [0, 0, ..., 1, 1, ...]: each batch row repeated beam_size times\n",
    "        indices = np.repeat(np.arange(self._batch_size), self._beam_size)\n",
    "        if hasattr(self._cell, \"_memory\"):\n",
    "            self._cell._memory = tf.gather(self._cell._memory, indices)\n",
    "\n",
    "        if hasattr(self._cell, \"_emo_cat_embs\"):\n",
    "            self._cell._emo_cat_embs = tf.gather(\n",
    "                self._cell._emo_cat_embs, indices)\n",
    "\n",
    "        if hasattr(self._cell, \"_emo_cat\"):\n",
    "            # update the wrapped cell too, so its emotion ids match its tiled\n",
    "            # memory and embeddings; also kept on self for existing callers\n",
    "            self._emo_cat = tf.gather(self._cell._emo_cat, indices)\n",
    "            self._cell._emo_cat = self._emo_cat\n",
    "\n",
    "    @property\n",
    "    def output_dtype(self):\n",
    "        \"\"\"dtype structure of BeamDecoderOutput, used to create the TensorArrays in dynamic_decode\"\"\"\n",
    "        return BeamDecoderOutput(\n",
    "            logits=self._dtype, ids=tf.int32, parents=tf.int32)\n",
    "\n",
    "    def _initial_state(self):\n",
    "        \"\"\"\n",
    "        Beam-tiled decoder states plus zero cumulative log-probabilities.\n",
    "        Returns:\n",
    "            BeamDecoderCellStates: cell_states tiled to [batch_size, beam_size, ...]\n",
    "                and log_probs of zeros, [batch_size, beam_size]\n",
    "        \"\"\"\n",
    "        tiled_states = nest.map_structure(\n",
    "            lambda state: tile_beam(state, self._beam_size),\n",
    "            self._dec_init_states)\n",
    "        zero_log_probs = tf.zeros([self._batch_size, self._beam_size],\n",
    "                                  dtype=self._dtype)\n",
    "        return BeamDecoderCellStates(tiled_states, zero_log_probs)\n",
    "\n",
    "    def initialize(self):\n",
    "        \"\"\"\n",
    "        Initial loop values for beam-search decoding.\n",
    "        Returns:\n",
    "            cell_states: BeamDecoderCellStates (tiled states, zero log_probs)\n",
    "            inputs: SOS embedding, [batch_size, beam_size, embed_size]\n",
    "            decode_finished: all False, [batch_size, beam_size]\n",
    "        \"\"\"\n",
    "        beam_states = self._initial_state()\n",
    "\n",
    "        sos = tf.reshape(self._start_token, [1, 1, -1])\n",
    "        inputs = tf.tile(sos, multiples=[self._batch_size, self._beam_size, 1])\n",
    "\n",
    "        nothing_finished = tf.zeros([self._batch_size, self._beam_size],\n",
    "                                    dtype=tf.bool)\n",
    "\n",
    "        return beam_states, inputs, nothing_finished\n",
    "\n",
    "    def step(self, time_index, beam_states, inputs, decode_finished):\n",
    "        \"\"\"\n",
    "        One beam-search decoding step.\n",
    "            time_index: current step index (compared with 0 inside tf.cond)\n",
    "            beam_states: BeamDecoderCellStates of the previous step\n",
    "            inputs: previous-word embeddings, [batch_size, beam_size, embed_size]\n",
    "            decode_finished: [batch_size, beam_size], bool\n",
    "        Returns:\n",
    "            (BeamDecoderOutput, BeamDecoderCellStates, next inputs,\n",
    "             new decode_finished); the output fields are\n",
    "            logits: [batch_size, beam_size, 2 * vocab_size], step log-probs\n",
    "            ids: [batch_size, beam_size], best words ids now\n",
    "            parents: [batch_size, beam_size], previous step beam index ids\n",
    "        \"\"\"\n",
    "        # 1-1: merge batch -> [batch_size*beam_size, ...]\n",
    "        cell_states = nest.map_structure(\n",
    "            merge_batch_beam, beam_states.cell_states)\n",
    "        inputs = merge_batch_beam(inputs)\n",
    "\n",
    "        # 1-2: perform cell ops to get new log probs\n",
    "        new_h, new_cell_states = self._cell.__call__(inputs, cell_states)\n",
    "        gen_log_probs = tf.nn.log_softmax(self._output_layer(new_h))\n",
    "        emo_log_probs = tf.nn.log_softmax(self._emo_output_layer(new_h))\n",
    "\n",
    "        # mix the vocabularies with alpha = sigmoid(choice_logits), using\n",
    "        # log(alpha) = log_sigmoid(x) and log(1 - alpha) = log_sigmoid(-x):\n",
    "        # unlike tf.log(1 - tf.sigmoid(x)) these stay finite when the\n",
    "        # sigmoid saturates, so no -inf (and no NaN downstream) enters the beams\n",
    "        choice_logits = self._emo_choice_layer(new_h)\n",
    "        gen_log_probs = gen_log_probs + tf.log_sigmoid(-choice_logits)\n",
    "        emo_log_probs = emo_log_probs + tf.log_sigmoid(choice_logits)\n",
    "        raw_log_probs = tf.concat([gen_log_probs, emo_log_probs], axis=-1)\n",
    "\n",
    "        # 1-3: split batch beam -> [batch_size, beam_size, ...]\n",
    "        raw_log_probs = split_batch_beam(raw_log_probs, self._beam_size)\n",
    "        new_cell_states = nest.map_structure(\n",
    "            lambda t: split_batch_beam(t, self._beam_size), new_cell_states)\n",
    "\n",
    "        # 2-1: mask log_probs, [batch_size, beam_size, vocab_size]\n",
    "        step_log_probs = mask_log_probs(\n",
    "            raw_log_probs, self._end_id, decode_finished)\n",
    "\n",
    "        # 2-2: add cumulative log_probs and \"diversity penalty\"\n",
    "        log_probs = tf.expand_dims(beam_states.log_probs, axis=-1)\n",
    "        log_probs = log_probs + step_log_probs\n",
    "        log_probs = add_diversity_penalty(log_probs, self._div_gamma,\n",
    "                                          self._div_prob, self._batch_size,\n",
    "                                          self._beam_size, self._vocab_size)\n",
    "\n",
    "        # 3-1: flatten, if time_index = 0, consider only one beam\n",
    "        # (all beams start identical, so beam 0 alone avoids duplicates)\n",
    "        # log_probs[:, 0]: [batch_size, vocab_size]\n",
    "        shape = [self._batch_size, self._beam_size * self._vocab_size]\n",
    "        log_probs_flat = tf.reshape(log_probs, shape)\n",
    "        log_probs_flat = tf.cond(time_index > 0, lambda: log_probs_flat,\n",
    "                                 lambda: log_probs[:, 0])\n",
    "\n",
    "        # 3-2: compute the top (beam_size) beams, [batch_size, beam_size]\n",
    "        new_log_probs, indices = tf.nn.top_k(log_probs_flat, self._beam_size)\n",
    "\n",
    "        # 3-3: obtain ids and parent beams\n",
    "        new_ids = indices % self._vocab_size\n",
    "        # //: floor division, know which beam it belongs to\n",
    "        new_parents = indices // self._vocab_size\n",
    "\n",
    "        # 4-1: compute new states\n",
    "        # ids >= origin vocab size come from the emotion half of the doubled\n",
    "        # vocabulary; both halves share one embedding table, hence the modulo\n",
    "        new_inputs = tf.nn.embedding_lookup(\n",
    "            self._embeddings, (new_ids % self._origin_vocab_size))\n",
    "\n",
    "        decode_finished = gather_helper(\n",
    "            decode_finished, new_parents, self._batch_size, self._beam_size)\n",
    "\n",
    "        new_decode_finished = tf.logical_or(\n",
    "            decode_finished, tf.equal(new_ids, self._end_id))\n",
    "\n",
    "        new_cell_states = nest.map_structure(\n",
    "            lambda t: gather_helper(t, new_parents, self._batch_size,\n",
    "                                    self._beam_size), new_cell_states)\n",
    "\n",
    "        # 4-2: create new state and output of decoder\n",
    "        new_beam_states = BeamDecoderCellStates(cell_states=new_cell_states,\n",
    "                                                log_probs=new_log_probs)\n",
    "        new_output = BeamDecoderOutput(logits=raw_log_probs, ids=new_ids,\n",
    "                                       parents=new_parents)\n",
    "\n",
    "        return (new_output, new_beam_states, new_inputs, new_decode_finished)\n",
    "\n",
    "    def finalize(self, final_outputs, final_cell_states):\n",
    "        \"\"\"\n",
    "        Backtrack the beam-search outputs into complete per-beam sequences.\n",
    "\n",
    "        Args:\n",
    "            final_outputs: BeamDecoderOutput(logits, ids, parents) of\n",
    "                stacked, time-major tensors (time is axis 0)\n",
    "            final_cell_states: BeamDecoderCellStates (not used here)\n",
    "        Returns:\n",
    "            DecoderOutput(logits, ids), time-major structure of tensors\n",
    "            [max_time, batch_size, beam_size, ...]; entry [t, b, k] is the\n",
    "            step-t value on the path that ends in final beam k of item b\n",
    "        \"\"\"\n",
    "        # reverse the time dimension so the loop walks from the last step\n",
    "        max_iter = tf.shape(final_outputs.ids)[0]\n",
    "        final_outputs = nest.map_structure(lambda t: tf.reverse(t, axis=[0]),\n",
    "                                           final_outputs)\n",
    "\n",
    "        # one fixed-size TensorArray per output field\n",
    "        def create_ta(d):\n",
    "            return tf.TensorArray(dtype=d, size=max_iter)\n",
    "\n",
    "        f_time_index = tf.constant(0, dtype=tf.int32)\n",
    "        # final output dtype\n",
    "        final_dtype = DecoderOutput(logits=self._dtype, ids=tf.int32)\n",
    "        f_output_ta = nest.map_structure(create_ta, final_dtype)\n",
    "\n",
    "        # at the last step, final beam k sits at position k:\n",
    "        # [batch_size, beam_size] = [[0, 1, ..., beam_size - 1], ...]\n",
    "        f_parents = tf.tile(\n",
    "            tf.expand_dims(tf.range(self._beam_size), axis=0),\n",
    "            multiples=[self._batch_size, 1])\n",
    "\n",
    "        def condition(f_time_index, output_ta, f_parents):\n",
    "            return tf.less(f_time_index, max_iter)\n",
    "\n",
    "        def body(f_time_index, output_ta, f_parents):\n",
    "            # get ids, logits and parents predicted at this time step\n",
    "            input_t = nest.map_structure(lambda t: t[f_time_index],\n",
    "                                         final_outputs)\n",
    "\n",
    "            # f_parents[b, k]: position, at this step, of the beam lying on\n",
    "            # the path of final beam k; gather every field from there\n",
    "            new_beam_state = nest.map_structure(\n",
    "                lambda t: gather_helper(t, f_parents, self._batch_size,\n",
    "                                        self._beam_size),\n",
    "                input_t)\n",
    "\n",
    "            # create new output\n",
    "            new_output = DecoderOutput(logits=new_beam_state.logits,\n",
    "                                       ids=new_beam_state.ids)\n",
    "\n",
    "            # write beam ids\n",
    "            output_ta = nest.map_structure(\n",
    "                lambda ta, out: ta.write(f_time_index, out),\n",
    "                output_ta, new_output)\n",
    "\n",
    "            # follow the chain: the next position is the parent of the\n",
    "            # *gathered* beam, not of whatever beam sits at position k\n",
    "            # (input_t.parents) -- otherwise paths get mixed up from the\n",
    "            # third-to-last step backwards\n",
    "            return (f_time_index + 1), output_ta, new_beam_state.parents\n",
    "\n",
    "        with tf.variable_scope(\"beam_search_decoding\"):\n",
    "            res = tf.while_loop(\n",
    "                    condition,\n",
    "                    body,\n",
    "                    loop_vars=[f_time_index, f_output_ta, f_parents],\n",
    "                    back_prop=False)\n",
    "\n",
    "        # stack the structure and reverse back\n",
    "        final_outputs = nest.map_structure(lambda ta: ta.stack(), res[1])\n",
    "        final_outputs = nest.map_structure(lambda t: tf.reverse(t, axis=[0]),\n",
    "                                           final_outputs)\n",
    "\n",
    "        return DecoderOutput(logits=final_outputs.logits,\n",
    "                             ids=final_outputs.ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Decoding function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dynamic decode function\n",
    "from tensorflow.contrib.framework import nest\n",
    "\n",
    "\n",
    "def transpose_batch_time(tensor):\n",
    "    \"\"\"\n",
    "    Swap the first two axes (time <-> batch) of a tensor of rank >= 2,\n",
    "    keeping every remaining axis in place.\n",
    "    \"\"\"\n",
    "    ndims = tensor.shape.ndims\n",
    "    if ndims is None:\n",
    "        # rank unknown at graph-build time: build the permutation at run time\n",
    "        perm = tf.concat([[1, 0], tf.range(2, tf.rank(tensor))], axis=0)\n",
    "    elif ndims < 2:\n",
    "        raise ValueError(\n",
    "            \"transpose_batch_time needs a tensor of rank >= 2, got rank %d\"\n",
    "            % ndims)\n",
    "    else:\n",
    "        # [1, 0, 2, 3, ...]: works for any static rank, not only 2, 3 or 4\n",
    "        perm = [1, 0] + list(range(2, ndims))\n",
    "    return tf.transpose(tensor, perm)\n",
    "\n",
    "\n",
    "# Dynamic decode function\n",
    "def dynamic_decode(decoder_cell, max_iter):\n",
    "    \"\"\"\n",
    "    Run `decoder_cell` step by step in a tf.while_loop until all of its\n",
    "    `decode_finished` flags are True. The flags are also forced to True\n",
    "    once the step counter reaches `max_iter`, so at most max_iter + 1\n",
    "    steps are taken.\n",
    "\n",
    "    Returns:\n",
    "        final_outputs: outputs of `decoder_cell.finalize`, transposed to\n",
    "            batch-major by transpose_batch_time\n",
    "        final_cell_states: cell states after the last loop iteration\n",
    "    \"\"\"\n",
    "    max_iter = tf.convert_to_tensor(max_iter, dtype=tf.int32)\n",
    "\n",
    "    def new_output_ta(dtype):\n",
    "        # starts empty and grows by one write per decoding step\n",
    "        return tf.TensorArray(dtype=dtype, size=0, dynamic_size=True)\n",
    "\n",
    "    start_step = tf.constant(0, dtype=tf.int32)\n",
    "    start_output_tas = nest.map_structure(new_output_ta,\n",
    "                                          decoder_cell.output_dtype)\n",
    "    start_states, start_inputs, start_finished = decoder_cell.initialize()\n",
    "\n",
    "    def keep_going(step, output_tas, states, inputs, finished):\n",
    "        # loop while at least one entry is still unfinished\n",
    "        return tf.reduce_any(tf.logical_not(finished))\n",
    "\n",
    "    def advance(step, output_tas, states, inputs, finished):\n",
    "        (step_output, next_states, next_inputs,\n",
    "         next_finished) = decoder_cell.step(step, states, inputs, finished)\n",
    "\n",
    "        # store this step's output at index `step` of every TensorArray\n",
    "        output_tas = nest.map_structure(\n",
    "            lambda ta, value: ta.write(step, value), output_tas, step_output)\n",
    "\n",
    "        # stop everything once the step budget is exhausted\n",
    "        out_of_budget = tf.greater_equal(step, max_iter)\n",
    "        next_finished = tf.logical_or(out_of_budget, next_finished)\n",
    "\n",
    "        return (step + 1, output_tas, next_states, next_inputs,\n",
    "                next_finished)\n",
    "\n",
    "    with tf.variable_scope(\"decoding\"):\n",
    "        loop_vars = [start_step, start_output_tas, start_states,\n",
    "                     start_inputs, start_finished]\n",
    "        loop_result = tf.while_loop(keep_going, advance,\n",
    "                                    loop_vars=loop_vars, back_prop=False)\n",
    "\n",
    "    _, final_output_tas, final_cell_states, _, _ = loop_result\n",
    "\n",
    "    # time-major stack: [number of steps, batch_size, ...]\n",
    "    stacked_outputs = nest.map_structure(lambda ta: ta.stack(),\n",
    "                                         final_output_tas)\n",
    "\n",
    "    # let the cell post-process (e.g. beam backtracking), then go batch-major\n",
    "    finalized = decoder_cell.finalize(stacked_outputs, final_cell_states)\n",
    "    final_outputs = nest.map_structure(transpose_batch_time, finalized)\n",
    "\n",
    "    return final_outputs, final_cell_states\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "\n",
    "def build_ECM_decoder(encoder_outputs, encoder_states, embeddings, num_layers,\n",
    "                      num_units, cell_type, num_emo, emo_cat, emo_cat_units,\n",
    "                      emo_int_units, state_pass=True, infer_batch_size=None,\n",
    "                      attn_num_units=128, target_ids=None, beam_size=None,\n",
    "                      max_iter=20, dtype=tf.float32, forget_bias=1.0,\n",
    "                      name=\"ECM_decoder\", div_gamma=None, div_prob=None):\n",
    "    \"\"\"\n",
    "    ECM decoder: build ECM decoder with emotion category embedding,\n",
    "             internal & external memory modules\n",
    "        encoder_outputs: attention memory for the ECM wrapper\n",
    "        encoder_states: initial decoder states (used only if state_pass)\n",
    "        embeddings: word embedding matrix, first dimension is vocab_size\n",
    "        num_layers, num_units, cell_type, forget_bias: decoder RNN settings\n",
    "            passed to build_rnn_cell\n",
    "        num_emo: number of emotions\n",
    "        emo_cat: emotion categories, [batch_size]\n",
    "        emo_cat_units: dimension of emotion category embeddings\n",
    "        emo_int_units: dimension of emotion internal memory\n",
    "        state_pass: start the decoder from encoder_states if True\n",
    "        infer_batch_size: batch size of the beam-search graph; if None, no\n",
    "            inference graph is built\n",
    "        attn_num_units: attention size passed to ECMWrapper\n",
    "        target_ids: [batch_size, max_time]; if None, no training graph is\n",
    "            built\n",
    "        beam_size: beam width, required when infer_batch_size is given\n",
    "        max_iter: decoding step limit passed to dynamic_decode\n",
    "        dtype: dtype for the emotion embeddings, ECMWrapper and decoding\n",
    "        name: variable scope of the decoder\n",
    "        div_gamma, div_prob: diversity-penalty settings forwarded to\n",
    "            ECMBeamSearchDecodeCell (None keeps the previous behavior)\n",
    "\n",
    "    Returns:\n",
    "        cell: the ECMWrapper-wrapped decoder cell\n",
    "        train_outputs: (generic_logits, emo_ext_logits, alphas, int_M_emo),\n",
    "            first 3 shape: [batch_size, max_time + 1, d] (decoder inputs are\n",
    "            target_ids with id 0 prepended); None without target_ids\n",
    "            int_M_emo: [batch_size, emo_int_units]\n",
    "        infer_outputs: namedtuple(logits, ids), ids shaped\n",
    "            [batch_size, decode_steps, beam_size]; None without\n",
    "            infer_batch_size\n",
    "\n",
    "    Raises:\n",
    "        ValueError: infer_batch_size is given but beam_size is None\n",
    "    \"\"\"\n",
    "    # parameter checking; stacklevel=2 reports the caller's line\n",
    "    if infer_batch_size is None:\n",
    "        txt = \"infer_batch_size not specified, infer output will be 'None'.\"\n",
    "        warnings.warn(txt, stacklevel=2)\n",
    "    elif beam_size is None:\n",
    "        raise ValueError(\"Inference by beam search must specify beam_size.\")\n",
    "\n",
    "    if target_ids is None:\n",
    "        txt = \"target_ids not specified, train_outputs will be 'None'.\"\n",
    "        warnings.warn(txt, stacklevel=2)\n",
    "\n",
    "    with tf.variable_scope(name):\n",
    "        vocab_size = embeddings.shape[0].value\n",
    "\n",
    "        # create emotion category embeddings\n",
    "        emo_init = tf.contrib.layers.xavier_initializer()\n",
    "        emo_cat_embeddings = tf.Variable(\n",
    "            emo_init(shape=(num_emo, emo_cat_units)),\n",
    "            name=\"emo_cat_embeddings\", dtype=dtype)\n",
    "        emo_cat_embs = tf.nn.embedding_lookup(emo_cat_embeddings, emo_cat)\n",
    "\n",
    "        # decoder rnn_cell\n",
    "        cell = build_rnn_cell(num_layers, num_units, cell_type, forget_bias)\n",
    "        dec_init_states = encoder_states if state_pass else None\n",
    "        output_layer = tf.layers.Dense(\n",
    "            vocab_size, use_bias=False, name=\"output_projection\")\n",
    "\n",
    "        # wrap with ECM internal memory module\n",
    "        memory = encoder_outputs\n",
    "\n",
    "        cell = ECMWrapper(\n",
    "            cell, memory, dec_init_states, attn_num_units, num_units, dtype,\n",
    "            emo_cat_embs, emo_cat, num_emo, emo_int_units)\n",
    "\n",
    "        dec_init_states = cell.initial_state()\n",
    "\n",
    "        # ECM external memory module\n",
    "        emo_output_layer = tf.layers.Dense(\n",
    "            vocab_size, use_bias=False, name=\"emo_output_projection\")\n",
    "\n",
    "        emo_choice_layer = tf.layers.Dense(\n",
    "            1, use_bias=False, name=\"emo_choice_alpha\")\n",
    "\n",
    "        # Decode - for training\n",
    "        # pad the token sequences with SOS (Start of Sentence), id 0; the\n",
    "        # decoder therefore runs max_time + 1 steps\n",
    "        train_outputs = None\n",
    "        if target_ids is not None:\n",
    "            input_ids = tf.pad(target_ids, [[0, 0], [1, 0]], constant_values=0)\n",
    "            embed_inputs = tf.nn.embedding_lookup(embeddings, input_ids)\n",
    "\n",
    "            decoder_outputs, decoder_states = tf.nn.dynamic_rnn(\n",
    "                cell, embed_inputs,\n",
    "                initial_state=dec_init_states,\n",
    "                dtype=dtype,\n",
    "                swap_memory=True)\n",
    "\n",
    "            # logits & final internal memory states\n",
    "            generic_logits = output_layer(decoder_outputs)\n",
    "            emo_ext_logits = emo_output_layer(decoder_outputs)\n",
    "            alphas = tf.nn.sigmoid(emo_choice_layer(decoder_outputs))\n",
    "            int_M_emo = decoder_states.internal_memory\n",
    "\n",
    "            train_outputs = (generic_logits, emo_ext_logits, alphas, int_M_emo)\n",
    "\n",
    "        # Decode - for inference, beam search\n",
    "        infer_outputs = None\n",
    "        if infer_batch_size is not None:\n",
    "            if dec_init_states is None:\n",
    "                dec_init_states = cell.zero_state(infer_batch_size, dtype)\n",
    "\n",
    "            decoder_cell = ECMBeamSearchDecodeCell(\n",
    "                embeddings, cell, dec_init_states, output_layer,\n",
    "                emo_output_layer, emo_choice_layer,\n",
    "                infer_batch_size, dtype, beam_size, vocab_size,\n",
    "                div_gamma=div_gamma, div_prob=div_prob)\n",
    "\n",
    "            # namedtuple(logits, ids)\n",
    "            infer_outputs, _ = dynamic_decode(decoder_cell, max_iter)\n",
    "\n",
    "    return cell, train_outputs, infer_outputs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/aaronlai/env3/lib/python3.5/site-packages/ipykernel_launcher.py:30: UserWarning: infer_batch_size not specified, infer output will be 'None'.\n"
     ]
    }
   ],
   "source": [
    "# training graph: placeholders for target token ids and emotion categories\n",
    "target_ids = tf.placeholder(tf.int64, [None, None])\n",
    "emo_cat = tf.placeholder(tf.int64, [None])\n",
    "\n",
    "cell, train_outputs, infer_outputs = build_ECM_decoder(\n",
    "    encoder_outputs, encoder_states, embeddings,\n",
    "    num_layers=2,\n",
    "    num_units=256,\n",
    "    cell_type=\"LSTM\",\n",
    "    num_emo=4,\n",
    "    emo_cat=emo_cat,\n",
    "    emo_cat_units=32,\n",
    "    emo_int_units=64,\n",
    "    state_pass=True,\n",
    "    target_ids=target_ids,\n",
    "    name=\"ECM_decoder2\")\n",
    "\n",
    "# (generic logits, external-memory logits, choice alphas, internal memory)\n",
    "g_logits, e_logits, alphas, M_emo = train_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[[0.00099677, 0.00099924, 0.00099844, ..., 0.00100008,\n",
       "          0.00099337, 0.0009939 ],\n",
       "         [0.00099765, 0.00099991, 0.00099775, ..., 0.00099824,\n",
       "          0.00099251, 0.00099288],\n",
       "         [0.00099824, 0.00100007, 0.00099719, ..., 0.00099472,\n",
       "          0.00099212, 0.0009924 ],\n",
       "         [0.00099853, 0.00099985, 0.00099692, ..., 0.00099043,\n",
       "          0.00099219, 0.00099222]],\n",
       " \n",
       "        [[0.0009976 , 0.00099881, 0.00099721, ..., 0.00099865,\n",
       "          0.0009938 , 0.00099695],\n",
       "         [0.00099704, 0.00100132, 0.00100033, ..., 0.00100075,\n",
       "          0.00099185, 0.00099846],\n",
       "         [0.00099719, 0.00100375, 0.00100389, ..., 0.00100124,\n",
       "          0.00099083, 0.00100002],\n",
       "         [0.0009972 , 0.00100644, 0.00100857, ..., 0.00100133,\n",
       "          0.00099003, 0.00100196]]], dtype=float32), array([[[0.49908936],\n",
       "         [0.49997437],\n",
       "         [0.5008866 ],\n",
       "         [0.5017154 ]],\n",
       " \n",
       "        [[0.49884552],\n",
       "         [0.497485  ],\n",
       "         [0.4964107 ],\n",
       "         [0.4961663 ]]], dtype=float32), array([[ 0.00042269,  0.01618492, -0.00378681, -0.01380468,  0.0026418 ,\n",
       "         -0.00878012, -0.01450922,  0.01607385, -0.01348603, -0.0028772 ,\n",
       "          0.0104529 ,  0.00498168,  0.01584366, -0.00512186,  0.0175215 ,\n",
       "         -0.01779045,  0.01453963,  0.01581778,  0.00668984,  0.00899979,\n",
       "         -0.01815362, -0.01509987,  0.00351812, -0.00444507, -0.00970728,\n",
       "          0.00417832,  0.01684892,  0.00215611, -0.01225092,  0.01367479,\n",
       "          0.00420546,  0.00924134, -0.0027324 , -0.0005969 , -0.00130309,\n",
       "          0.00444099,  0.00738562,  0.01504425, -0.00949223, -0.01247715,\n",
       "         -0.00616199, -0.01814302, -0.00527934,  0.0104714 ,  0.01085892,\n",
       "          0.01467417,  0.00498184,  0.01331664,  0.01636752, -0.01348853,\n",
       "          0.00889739,  0.00730877, -0.00560364,  0.011374  ,  0.01460901,\n",
       "         -0.0144643 ,  0.00781834,  0.01690231,  0.00935061, -0.00502634,\n",
       "         -0.01914359, -0.01475261, -0.00963213,  0.00639727],\n",
       "        [-0.00292162, -0.01754373, -0.0163943 , -0.00752104, -0.01576044,\n",
       "          0.0089379 ,  0.01266023, -0.01006428, -0.01126559, -0.00828973,\n",
       "         -0.01242807,  0.00760068, -0.00988987,  0.00318029, -0.00894936,\n",
       "          0.00973476,  0.01594216,  0.01024849, -0.01510452,  0.00766317,\n",
       "         -0.00769749, -0.00775469,  0.00916572, -0.00635401,  0.00665659,\n",
       "         -0.01673188,  0.00963973, -0.00388604, -0.00673485,  0.01642816,\n",
       "          0.01512653,  0.00158559, -0.0143876 , -0.01560979,  0.00699725,\n",
       "         -0.00426637, -0.01536006, -0.00020855,  0.00992646, -0.00431674,\n",
       "         -0.01138919,  0.00459377,  0.0129175 , -0.01009319, -0.0151385 ,\n",
       "         -0.01307235,  0.01729266,  0.01678321,  0.01788699,  0.00795764,\n",
       "         -0.00421176, -0.01739   ,  0.01679986,  0.00119436,  0.00574261,\n",
       "          0.00067792,  0.00448814, -0.00728327,  0.0089157 , -0.01309618,\n",
       "         -0.00084172, -0.01756798,  0.01201619,  0.00115901]],\n",
       "       dtype=float32)]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# smoke test: run the training graph on a tiny hand-made batch\n",
    "gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05)\n",
    "session_config = tf.ConfigProto(log_device_placement=False,\n",
    "                                gpu_options=gpu_options)\n",
    "sess = tf.Session(config=session_config)\n",
    "\n",
    "init = tf.global_variables_initializer()\n",
    "sess.run(init)\n",
    "\n",
    "g_probs = tf.nn.softmax(g_logits)\n",
    "train_feed = {\n",
    "    source_ids: [[3, 3, 3], [4, 5, 6]],\n",
    "    target_ids: [[8, 8, 8], [8, 10, 12]],\n",
    "    emo_cat: [0, 1],\n",
    "}\n",
    "results = sess.run([g_probs, alphas, M_emo], feed_dict=train_feed)\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/aaronlai/env3/lib/python3.5/site-packages/ipykernel_launcher.py:36: UserWarning: target_ids not specified, train_outputs will be 'None'.\n"
     ]
    }
   ],
   "source": [
    "# inference graph: beam search only (no target_ids, so train_outputs is None)\n",
    "cell, train_outputs, infer_outputs = build_ECM_decoder(\n",
    "    encoder_outputs, encoder_states, embeddings,\n",
    "    num_layers=2,\n",
    "    num_units=256,\n",
    "    cell_type=\"LSTM\",\n",
    "    num_emo=4,\n",
    "    emo_cat=emo_cat,\n",
    "    emo_cat_units=32,\n",
    "    emo_int_units=64,\n",
    "    state_pass=False,\n",
    "    infer_batch_size=3,\n",
    "    beam_size=5,\n",
    "    max_iter=10,\n",
    "    name=\"ECM_infer1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DecoderOutput(logits=array([[[[-7.603438 , -7.599129 , -7.6011157, ..., -7.6059103,\n",
       "          -7.604543 , -7.6052303],\n",
       "         [-7.603438 , -7.599129 , -7.6011157, ..., -7.6059103,\n",
       "          -7.604543 , -7.6052303],\n",
       "         [-7.603438 , -7.599129 , -7.6011157, ..., -7.6059103,\n",
       "          -7.604543 , -7.6052303],\n",
       "         [-7.603438 , -7.599129 , -7.6011157, ..., -7.6059103,\n",
       "          -7.604543 , -7.6052303],\n",
       "         [-7.603438 , -7.599129 , -7.6011157, ..., -7.6059103,\n",
       "          -7.604543 , -7.6052303]],\n",
       "\n",
       "        [[-7.6066465, -7.594799 , -7.599973 , ..., -7.605916 ,\n",
       "          -7.606241 , -7.6054983],\n",
       "         [-7.605527 , -7.595107 , -7.600978 , ..., -7.6061954,\n",
       "          -7.6056223, -7.6059036],\n",
       "         [-7.6048017, -7.594049 , -7.5993166, ..., -7.6071897,\n",
       "          -7.606323 , -7.606474 ],\n",
       "         [-7.605527 , -7.595107 , -7.600978 , ..., -7.6061954,\n",
       "          -7.6056223, -7.6059036],\n",
       "         [-7.6054416, -7.5947514, -7.6001225, ..., -7.6061554,\n",
       "          -7.6055098, -7.605599 ]],\n",
       "\n",
       "        [[-7.6079197, -7.589979 , -7.5998497, ..., -7.6069884,\n",
       "          -7.608154 , -7.607124 ],\n",
       "         [-7.608794 , -7.5912895, -7.602131 , ..., -7.6056027,\n",
       "          -7.607213 , -7.6064296],\n",
       "         [-7.608794 , -7.5912895, -7.602131 , ..., -7.6056027,\n",
       "          -7.607213 , -7.6064296],\n",
       "         [-7.60974  , -7.591046 , -7.6002936, ..., -7.6049814,\n",
       "          -7.6073904, -7.6062226],\n",
       "         [-7.60974  , -7.591046 , -7.6002936, ..., -7.6049814,\n",
       "          -7.6073904, -7.6062226]],\n",
       "\n",
       "        ...,\n",
       "\n",
       "        [[-7.625336 , -7.563318 , -7.5974092, ..., -7.6087832,\n",
       "          -7.6251087, -7.6145024],\n",
       "         [-7.624126 , -7.5627723, -7.5965858, ..., -7.610784 ,\n",
       "          -7.6263266, -7.6150565],\n",
       "         [-7.625336 , -7.563318 , -7.5974092, ..., -7.6087832,\n",
       "          -7.6251087, -7.6145024],\n",
       "         [-7.625336 , -7.563318 , -7.5974092, ..., -7.6087832,\n",
       "          -7.6251087, -7.6145024],\n",
       "         [-7.6245127, -7.564306 , -7.598742 , ..., -7.608622 ,\n",
       "          -7.624615 , -7.614156 ]],\n",
       "\n",
       "        [[-7.626549 , -7.558388 , -7.593225 , ..., -7.6116123,\n",
       "          -7.628853 , -7.6164274],\n",
       "         [-7.6267853, -7.5590725, -7.5950847, ..., -7.6106615,\n",
       "          -7.628305 , -7.616138 ],\n",
       "         [-7.627456 , -7.560353 , -7.597419 , ..., -7.609232 ,\n",
       "          -7.6273546, -7.6155443],\n",
       "         [-7.6267853, -7.5590725, -7.5950847, ..., -7.6106615,\n",
       "          -7.628305 , -7.616138 ],\n",
       "         [-7.6267853, -7.5590725, -7.5950847, ..., -7.6106615,\n",
       "          -7.628305 , -7.616138 ]],\n",
       "\n",
       "        [[-7.628136 , -7.555287 , -7.5927305, ..., -7.6124153,\n",
       "          -7.631188 , -7.617624 ],\n",
       "         [-7.6286607, -7.5563865, -7.594986 , ..., -7.6110663,\n",
       "          -7.6303205, -7.6171246],\n",
       "         [-7.6293225, -7.5560656, -7.593089 , ..., -7.6118665,\n",
       "          -7.6308107, -7.617531 ],\n",
       "         [-7.62864  , -7.557162 , -7.5939837, ..., -7.6111774,\n",
       "          -7.629924 , -7.6163235],\n",
       "         [-7.6274557, -7.5550914, -7.5919814, ..., -7.613942 ,\n",
       "          -7.6321187, -7.61798  ]]],\n",
       "\n",
       "\n",
       "       [[[-7.6022053, -7.601361 , -7.6024394, ..., -7.6053767,\n",
       "          -7.604671 , -7.603977 ],\n",
       "         [-7.6022053, -7.601361 , -7.6024394, ..., -7.6053767,\n",
       "          -7.604671 , -7.603977 ],\n",
       "         [-7.6022053, -7.601361 , -7.6024394, ..., -7.6053767,\n",
       "          -7.604671 , -7.603977 ],\n",
       "         [-7.6022053, -7.601361 , -7.6024394, ..., -7.6053767,\n",
       "          -7.604671 , -7.603977 ],\n",
       "         [-7.6022053, -7.601361 , -7.6024394, ..., -7.6053767,\n",
       "          -7.604671 , -7.603977 ]],\n",
       "\n",
       "        [[-7.603071 , -7.6011057, -7.602374 , ..., -7.6052694,\n",
       "          -7.605212 , -7.601905 ],\n",
       "         [-7.6030946, -7.6013656, -7.6026506, ..., -7.6045794,\n",
       "          -7.606336 , -7.602065 ],\n",
       "         [-7.603071 , -7.6011057, -7.602374 , ..., -7.6052694,\n",
       "          -7.605212 , -7.601905 ],\n",
       "         [-7.6027193, -7.600962 , -7.602138 , ..., -7.6062756,\n",
       "          -7.606288 , -7.6029186],\n",
       "         [-7.603416 , -7.6010685, -7.6034427, ..., -7.6049213,\n",
       "          -7.604841 , -7.6022162]],\n",
       "\n",
       "        [[-7.60555  , -7.6021385, -7.602995 , ..., -7.6042285,\n",
       "          -7.6054015, -7.5987473],\n",
       "         [-7.606103 , -7.6020594, -7.604679 , ..., -7.6037035,\n",
       "          -7.604823 , -7.59908  ],\n",
       "         [-7.6051846, -7.602059 , -7.602821 , ..., -7.605633 ,\n",
       "          -7.6068654, -7.600179 ],\n",
       "         [-7.604868 , -7.602155 , -7.6037188, ..., -7.604468 ,\n",
       "          -7.6059403, -7.599351 ],\n",
       "         [-7.60555  , -7.6021385, -7.602995 , ..., -7.6042285,\n",
       "          -7.6054015, -7.5987473]],\n",
       "\n",
       "        ...,\n",
       "\n",
       "        [[-7.6194434, -7.611935 , -7.6121254, ..., -7.5986824,\n",
       "          -7.6088285, -7.590448 ],\n",
       "         [-7.621096 , -7.6121244, -7.610786 , ..., -7.5974407,\n",
       "          -7.6081057, -7.5901256],\n",
       "         [-7.6205873, -7.6121798, -7.611241 , ..., -7.5977883,\n",
       "          -7.6083355, -7.590326 ],\n",
       "         [-7.619691 , -7.611676 , -7.6108475, ..., -7.598347 ,\n",
       "          -7.6091013, -7.5895486],\n",
       "         [-7.6205873, -7.6121798, -7.611241 , ..., -7.5977883,\n",
       "          -7.6083355, -7.590326 ]],\n",
       "\n",
       "        [[-7.622116 , -7.6133394, -7.6117764, ..., -7.5974174,\n",
       "          -7.6089334, -7.5906363],\n",
       "         [-7.620977 , -7.6130967, -7.612664 , ..., -7.598315 ,\n",
       "          -7.6094193, -7.5907593],\n",
       "         [-7.622621 , -7.613284 , -7.611319 , ..., -7.5970736,\n",
       "          -7.6087055, -7.590439 ],\n",
       "         [-7.620977 , -7.6130967, -7.612664 , ..., -7.598315 ,\n",
       "          -7.6094193, -7.5907593],\n",
       "         [-7.622734 , -7.6130724, -7.6111174, ..., -7.5970564,\n",
       "          -7.6086874, -7.5902767]],\n",
       "\n",
       "        [[-7.6249914, -7.614518 , -7.610626 , ..., -7.595913 ,\n",
       "          -7.6087914, -7.590852 ],\n",
       "         [-7.625603 , -7.6142497, -7.6099615, ..., -7.595555 ,\n",
       "          -7.608554 , -7.590495 ],\n",
       "         [-7.625017 , -7.6140575, -7.6104207, ..., -7.59604  ,\n",
       "          -7.6088514, -7.590577 ],\n",
       "         [-7.625107 , -7.6143093, -7.610426 , ..., -7.5958986,\n",
       "          -7.608775 , -7.590692 ],\n",
       "         [-7.623888 , -7.6138263, -7.611322 , ..., -7.596936 ,\n",
       "          -7.6093297, -7.5906935]]],\n",
       "\n",
       "\n",
       "       [[[-7.602971 , -7.6004305, -7.5992026, ..., -7.6057777,\n",
       "          -7.603143 , -7.6047397],\n",
       "         [-7.602971 , -7.6004305, -7.5992026, ..., -7.6057777,\n",
       "          -7.603143 , -7.6047397],\n",
       "         [-7.602971 , -7.6004305, -7.5992026, ..., -7.6057777,\n",
       "          -7.603143 , -7.6047397],\n",
       "         [-7.602971 , -7.6004305, -7.5992026, ..., -7.6057777,\n",
       "          -7.603143 , -7.6047397],\n",
       "         [-7.602971 , -7.6004305, -7.5992026, ..., -7.6057777,\n",
       "          -7.603143 , -7.6047397]],\n",
       "\n",
       "        [[-7.6039686, -7.5992556, -7.595621 , ..., -7.6057568,\n",
       "          -7.603115 , -7.6037626],\n",
       "         [-7.6048875, -7.599878 , -7.596942 , ..., -7.6055284,\n",
       "          -7.6018453, -7.6031384],\n",
       "         [-7.6039686, -7.5992556, -7.595621 , ..., -7.6057568,\n",
       "          -7.603115 , -7.6037626],\n",
       "         [-7.603637 , -7.5984964, -7.5949316, ..., -7.606773 ,\n",
       "          -7.602876 , -7.604298 ],\n",
       "         [-7.604011 , -7.599472 , -7.595298 , ..., -7.606364 ,\n",
       "          -7.6020045, -7.604344 ]],\n",
       "\n",
       "        [[-7.606044 , -7.599559 , -7.5934353, ..., -7.6049256,\n",
       "          -7.601792 , -7.6016603],\n",
       "         [-7.606044 , -7.599559 , -7.5934353, ..., -7.6049256,\n",
       "          -7.601792 , -7.6016603],\n",
       "         [-7.605551 , -7.598506 , -7.592332 , ..., -7.6060553,\n",
       "          -7.602985 , -7.602515 ],\n",
       "         [-7.6055737, -7.598447 , -7.592437 , ..., -7.606386 ,\n",
       "          -7.6013107, -7.6023173],\n",
       "         [-7.606044 , -7.599559 , -7.5934353, ..., -7.6049256,\n",
       "          -7.601792 , -7.6016603]],\n",
       "\n",
       "        ...,\n",
       "\n",
       "        [[-7.6119466, -7.607239 , -7.589416 , ..., -7.6003513,\n",
       "          -7.6010756, -7.5937476],\n",
       "         [-7.611604 , -7.606883 , -7.589822 , ..., -7.6002254,\n",
       "          -7.599988 , -7.593976 ],\n",
       "         [-7.611604 , -7.606883 , -7.589822 , ..., -7.6002254,\n",
       "          -7.599988 , -7.593976 ],\n",
       "         [-7.6126676, -7.607599 , -7.5904846, ..., -7.600297 ,\n",
       "          -7.601719 , -7.594397 ],\n",
       "         [-7.6126676, -7.607599 , -7.5904846, ..., -7.600297 ,\n",
       "          -7.601719 , -7.594397 ]],\n",
       "\n",
       "        [[-7.6130247, -7.6085196, -7.5903687, ..., -7.5993333,\n",
       "          -7.6018853, -7.593522 ],\n",
       "         [-7.612967 , -7.6085277, -7.590172 , ..., -7.599313 ,\n",
       "          -7.6018023, -7.593417 ],\n",
       "         [-7.6138406, -7.608591 , -7.5912495, ..., -7.598647 ,\n",
       "          -7.6019497, -7.5931973],\n",
       "         [-7.6122994, -7.608014 , -7.590353 , ..., -7.5993996,\n",
       "          -7.600994 , -7.593799 ],\n",
       "         [-7.612967 , -7.6085277, -7.590172 , ..., -7.599313 ,\n",
       "          -7.6018023, -7.593417 ]],\n",
       "\n",
       "        [[-7.614052 , -7.6095753, -7.5912824, ..., -7.5982375,\n",
       "          -7.602423 , -7.593064 ],\n",
       "         [-7.6131635, -7.6089196, -7.591197 , ..., -7.5984507,\n",
       "          -7.6018343, -7.593472 ],\n",
       "         [-7.613207 , -7.608748 , -7.591107 , ..., -7.5977974,\n",
       "          -7.601374 , -7.592779 ],\n",
       "         [-7.6149   , -7.609424 , -7.592272 , ..., -7.597641 ,\n",
       "          -7.602467 , -7.5926185],\n",
       "         [-7.6148496, -7.609589 , -7.5923576, ..., -7.5982976,\n",
       "          -7.6029277, -7.5933146]]]], dtype=float32), ids=array([[[ 818,  205,  818,  173,  464],\n",
       "        [ 818,  818,  818,  818,  818],\n",
       "        [ 818,  818,  818,  818,  818],\n",
       "        [ 205,  818,  818,  205,  205],\n",
       "        [ 205,  205,  205,  205,  205],\n",
       "        [ 205,  205,  205,  205,  205],\n",
       "        [ 205,  205,  205,  205,  818],\n",
       "        [ 205,  205,  205,  205,  205],\n",
       "        [ 205,  205,  205,  205,  205],\n",
       "        [   1,  205,  205,  205,  205],\n",
       "        [   1,  205,  205,  918,    1]],\n",
       "\n",
       "       [[ 323,  907,  412,  323,  419],\n",
       "        [ 323,  323,  323,  323,  323],\n",
       "        [1645, 1645, 1065, 1645, 1645],\n",
       "        [1645, 1645, 1645, 1065, 1645],\n",
       "        [1645, 1645, 1645, 1645, 1645],\n",
       "        [1645, 1645, 1645, 1645, 1645],\n",
       "        [1645, 1645, 1645, 1645, 1007],\n",
       "        [1645, 1645, 1645, 1007, 1007],\n",
       "        [1007, 1007, 1645, 1645, 1645],\n",
       "        [1007, 1645, 1007, 1645, 1007],\n",
       "        [1007, 1007, 1007, 1645, 1007]],\n",
       "\n",
       "       [[ 419,  739,  739,  739,  739],\n",
       "        [  35,  419,   35,   35,  739],\n",
       "        [  35,   35, 1900,   35,   35],\n",
       "        [  35, 1900, 1900, 1900,  772],\n",
       "        [1900, 1900, 1900,   35, 1900],\n",
       "        [1900, 1900,   35, 1900,  772],\n",
       "        [ 772, 1900, 1500, 1900, 1900],\n",
       "        [1500, 1500, 1500, 1500, 1500],\n",
       "        [1500, 1500, 1500, 1500, 1500],\n",
       "        [1197, 1500, 1500, 1500, 1500],\n",
       "        [1197, 1500, 1197, 1197, 1197]]], dtype=int32))"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "init = tf.global_variables_initializer()\n",
    "sess.run(init)\n",
    "inf_results = sess.run(infer_outputs,\n",
    "                       feed_dict={\n",
    "                           source_ids: [[3, 3, 3], [4, 5, 6], [10, 10, 10]],\n",
    "                           emo_cat: [1, 2, 3],\n",
    "                       })\n",
    "inf_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[ 818,  205,  818,  173,  464],\n",
       "        [ 818,  818,  818,  818,  818],\n",
       "        [ 818,  818,  818,  818,  818],\n",
       "        [ 205,  818,  818,  205,  205],\n",
       "        [ 205,  205,  205,  205,  205],\n",
       "        [ 205,  205,  205,  205,  205],\n",
       "        [ 205,  205,  205,  205,  818],\n",
       "        [ 205,  205,  205,  205,  205],\n",
       "        [ 205,  205,  205,  205,  205],\n",
       "        [   1,  205,  205,  205,  205],\n",
       "        [   1,  205,  205,  918,    1]],\n",
       "\n",
       "       [[ 323,  907,  412,  323,  419],\n",
       "        [ 323,  323,  323,  323,  323],\n",
       "        [1645, 1645, 1065, 1645, 1645],\n",
       "        [1645, 1645, 1645, 1065, 1645],\n",
       "        [1645, 1645, 1645, 1645, 1645],\n",
       "        [1645, 1645, 1645, 1645, 1645],\n",
       "        [1645, 1645, 1645, 1645, 1007],\n",
       "        [1645, 1645, 1645, 1007, 1007],\n",
       "        [1007, 1007, 1645, 1645, 1645],\n",
       "        [1007, 1645, 1007, 1645, 1007],\n",
       "        [1007, 1007, 1007, 1645, 1007]],\n",
       "\n",
       "       [[ 419,  739,  739,  739,  739],\n",
       "        [  35,  419,   35,   35,  739],\n",
       "        [  35,   35, 1900,   35,   35],\n",
       "        [  35, 1900, 1900, 1900,  772],\n",
       "        [1900, 1900, 1900,   35, 1900],\n",
       "        [1900, 1900,   35, 1900,  772],\n",
       "        [ 772, 1900, 1500, 1900, 1900],\n",
       "        [1500, 1500, 1500, 1500, 1500],\n",
       "        [1500, 1500, 1500, 1500, 1500],\n",
       "        [1197, 1500, 1500, 1500, 1500],\n",
       "        [1197, 1500, 1197, 1197, 1197]]], dtype=int32)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inf_results.ids"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training and Saving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute ECM loss\n",
    "def compute_ECM_loss(source_ids, target_ids, sequence_mask, choice_qs,\n",
    "                     embeddings, enc_num_layers, enc_num_units, enc_cell_type,\n",
    "                     enc_bidir, dec_num_layers, dec_num_units, dec_cell_type,\n",
    "                     state_pass, num_emo, emo_cat, emo_cat_units,\n",
    "                     emo_int_units, infer_batch_size, beam_size=None,\n",
    "                     max_iter=20, attn_num_units=128, l2_regularize=None,\n",
    "                     name=\"ECM\"):\n",
    "    \"\"\"\n",
    "    Creates an ECM model and returns CE loss plus regularization terms.\n",
    "        choice_qs: [batch_size, max_time], true choice btw generic/emo words\n",
    "        emo_cat: [batch_size], emotion categories of each target sequence\n",
    "\n",
    "    Returns\n",
    "        CE: cross entropy, used to compute perplexity\n",
    "        total_loss: objective of the entire model\n",
    "        train_outs: (cell, log_probs, alphas, final_int_mem_states)\n",
    "            alphas - predicted choices\n",
    "        infer_outputs: namedtuple(logits, ids), [batch_size, max_time, d]\n",
    "    \"\"\"\n",
    "    with tf.name_scope(name):\n",
    "        # build encoder\n",
    "        encoder_outputs, encoder_states = build_encoder(\n",
    "            embeddings, source_ids, enc_num_layers, enc_num_units,\n",
    "            enc_cell_type, bidir=enc_bidir, name=\"%s_encoder\" % name)\n",
    "\n",
    "        # build decoder: logits, [batch_size, max_time, vocab_size]\n",
    "        cell, train_outputs, infer_outputs = build_ECM_decoder(\n",
    "            encoder_outputs, encoder_states, embeddings,\n",
    "            dec_num_layers, dec_num_units, dec_cell_type, \n",
    "            num_emo, emo_cat, emo_cat_units, emo_int_units,\n",
    "            state_pass, infer_batch_size, attn_num_units,\n",
    "            target_ids, beam_size, max_iter,\n",
    "            name=\"%s_decoder\" % name)\n",
    "\n",
    "        g_logits, e_logits, alphas, int_M_emo = train_outputs\n",
    "        g_probs = tf.nn.softmax(g_logits) * (1 - alphas)\n",
    "        e_probs = tf.nn.softmax(e_logits) * alphas\n",
    "        train_log_probs = tf.log(g_probs + e_probs)\n",
    "\n",
    "        with tf.name_scope('loss'):\n",
    "            final_ids = tf.pad(target_ids, [[0, 0], [0, 1]], constant_values=1)\n",
    "            alphas = tf.squeeze(alphas, axis=-1)\n",
    "            choice_qs = tf.pad(choice_qs, [[0, 0], [0, 1]], constant_values=0)\n",
    "\n",
    "            # compute losses\n",
    "            g_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "                logits=g_logits, labels=final_ids) - tf.log(1 - alphas)\n",
    "\n",
    "            e_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "                logits=e_logits, labels=final_ids) - tf.log(alphas)\n",
    "\n",
    "            losses = g_losses * (1 - choice_qs) + e_losses * choice_qs\n",
    "\n",
    "            # alpha and internal memory regularizations\n",
    "            alpha_reg = tf.reduce_mean(choice_qs * -tf.log(alphas))\n",
    "            int_mem_reg = tf.reduce_mean(tf.norm(int_M_emo, axis=1))\n",
    "\n",
    "            losses = tf.boolean_mask(losses[:, :-1], sequence_mask)\n",
    "            reduced_loss = tf.reduce_mean(losses) + alpha_reg + int_mem_reg\n",
    "\n",
    "            # prepare for perplexity computations\n",
    "            CE = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "                logits=train_log_probs, labels=final_ids)\n",
    "            CE = tf.boolean_mask(CE[:, :-1], sequence_mask)\n",
    "            CE = tf.reduce_sum(CE)\n",
    "\n",
    "            train_outs = (cell, train_log_probs, alphas, int_M_emo)\n",
    "            if l2_regularize is None:\n",
    "                return CE, reduced_loss, train_outs, infer_outputs\n",
    "            else:\n",
    "                l2_loss = tf.add_n([tf.nn.l2_loss(v)\n",
    "                                    for v in tf.trainable_variables()\n",
    "                                    if not('bias' in v.name)])\n",
    "\n",
    "                total_loss = reduced_loss + l2_regularize * l2_loss\n",
    "                return CE, total_loss, train_outs, infer_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_perplexity(sess, CE, mask, feed_dict):\n",
    "    \"\"\"\n",
    "    Compute perplexity for a batch of data\n",
    "    \"\"\"\n",
    "    CE_words = sess.run(CE, feed_dict=feed_dict)\n",
    "    N_words = np.sum(mask)\n",
    "    return np.exp(CE_words / N_words)\n",
    "\n",
    "\n",
    "def loadfile(filename, is_source, max_length):\n",
    "    \"\"\"\n",
    "    Load and clean data\n",
    "    \"\"\"\n",
    "    def clean(row):\n",
    "        row = np.array(row.split(), dtype=np.int32)\n",
    "        leng = len(row)\n",
    "        if leng < max_length:\n",
    "            if is_source:\n",
    "                # represents constant zero padding\n",
    "                pads = -np.ones(max_length - leng, dtype=np.int32)\n",
    "                row = np.concatenate((pads, row))\n",
    "            else:\n",
    "                # represents EOS token\n",
    "                pads = -2 * np.ones(max_length - leng, dtype=np.int32)\n",
    "                row = np.concatenate((row, pads))\n",
    "        elif leng > max_length:\n",
    "            row = row[:max_length]\n",
    "        return row\n",
    "\n",
    "    df = pd.read_csv(filename, header=None, index_col=None)\n",
    "    data = np.array(df[0].apply(lambda t: clean(t)).tolist(), dtype=np.int32)\n",
    "    return data\n",
    "\n",
    "# saving and load\n",
    "def load(saver, sess, logdir):\n",
    "    print(\"Trying to restore saved checkpoints from {} ...\".format(logdir),\n",
    "          end=\"\")\n",
    "\n",
    "    ckpt = tf.train.get_checkpoint_state(logdir)\n",
    "    if ckpt:\n",
    "        print(\"  Checkpoint found: {}\".format(ckpt.model_checkpoint_path))\n",
    "        global_step = int(ckpt.model_checkpoint_path\n",
    "                          .split('/')[-1]\n",
    "                          .split('-')[-1])\n",
    "        print(\"  Global step was: {}\".format(global_step))\n",
    "        print(\"  Restoring...\", end=\"\")\n",
    "        saver.restore(sess, ckpt.model_checkpoint_path)\n",
    "        print(\" Done.\")\n",
    "        return global_step\n",
    "    else:\n",
    "        print(\" No checkpoint found.\")\n",
    "        return None\n",
    "\n",
    "\n",
    "def save(saver, sess, logdir, step):\n",
    "    model_name = 'model.ckpt'\n",
    "    checkpoint_path = os.path.join(logdir, model_name)\n",
    "    print('Storing checkpoint to {} ...'.format(logdir), end=\"\")\n",
    "    sys.stdout.flush()\n",
    "\n",
    "    if not os.path.exists(logdir):\n",
    "        os.makedirs(logdir)\n",
    "\n",
    "    saver.save(sess, checkpoint_path, global_step=step)\n",
    "    print(' Done.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# YAML configuration\n",
    "def get_ECM_config(config):\n",
    "    enc_num_layers = config[\"encoder\"][\"num_layers\"]\n",
    "    enc_num_units = config[\"encoder\"][\"num_units\"]\n",
    "    enc_cell_type = config[\"encoder\"][\"cell_type\"]\n",
    "    enc_bidir = config[\"encoder\"][\"bidirectional\"]\n",
    "\n",
    "    dec_num_layers = config[\"decoder\"][\"num_layers\"]\n",
    "    dec_num_units = config[\"decoder\"][\"num_units\"]\n",
    "    dec_cell_type = config[\"decoder\"][\"cell_type\"]\n",
    "    state_pass = config[\"decoder\"][\"state_pass\"]\n",
    "\n",
    "    num_emo = config[\"decoder\"][\"num_emotions\"]\n",
    "    emo_cat_units = config[\"decoder\"][\"emo_cat_units\"]\n",
    "    emo_int_units = config[\"decoder\"][\"emo_int_units\"]\n",
    "\n",
    "    infer_batch_size = config[\"inference\"][\"infer_batch_size\"]\n",
    "    beam_size = config[\"inference\"][\"beam_size\"]\n",
    "    max_iter = config[\"inference\"][\"max_length\"]\n",
    "    attn_num_units = config[\"decoder\"][\"attn_num_units\"]\n",
    "    l2_regularize = config[\"training\"][\"l2_regularize\"]\n",
    "\n",
    "    return (enc_num_layers, enc_num_units, enc_cell_type, enc_bidir,\n",
    "            dec_num_layers, dec_num_units, dec_cell_type, state_pass,\n",
    "            num_emo, emo_cat_units, emo_int_units, infer_batch_size,\n",
    "            beam_size, max_iter, attn_num_units, l2_regularize)\n",
    "\n",
    "\n",
    "def get_ECM_training_config(config):\n",
    "    train_config = config[\"training\"]\n",
    "    logdir = train_config[\"logdir\"]\n",
    "    restore_from = train_config[\"restore_from\"]\n",
    "\n",
    "    learning_rate = train_config[\"learning_rate\"]\n",
    "    gpu_fraction = train_config[\"gpu_fraction\"]\n",
    "    max_checkpoints = train_config[\"max_checkpoints\"]\n",
    "    train_steps = train_config[\"train_steps\"]\n",
    "    batch_size = train_config[\"batch_size\"]\n",
    "    print_every = train_config[\"print_every\"]\n",
    "    checkpoint_every = train_config[\"checkpoint_every\"]\n",
    "\n",
    "    s_filename = train_config[\"train_source_file\"]\n",
    "    t_filename = train_config[\"train_target_file\"]\n",
    "    q_filename = train_config[\"train_choice_file\"]\n",
    "    c_filename = train_config[\"train_category_file\"]\n",
    "\n",
    "    s_max_leng = train_config[\"source_max_length\"]\n",
    "    t_max_leng = train_config[\"target_max_length\"]\n",
    "\n",
    "    dev_s_filename = train_config[\"dev_source_file\"]\n",
    "    dev_t_filename = train_config[\"dev_target_file\"]\n",
    "    dev_q_filename = train_config[\"dev_choice_file\"]\n",
    "    dev_c_filename = train_config[\"dev_category_file\"]\n",
    "\n",
    "    loss_fig = train_config[\"loss_fig\"]\n",
    "    perp_fig = train_config[\"perplexity_fig\"]\n",
    "\n",
    "    return (logdir, restore_from, learning_rate, gpu_fraction, max_checkpoints,\n",
    "            train_steps, batch_size, print_every, checkpoint_every,\n",
    "            s_filename, t_filename, q_filename, c_filename,\n",
    "            s_max_leng, t_max_leng, dev_s_filename, dev_t_filename,\n",
    "            dev_q_filename, dev_c_filename, loss_fig, perp_fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "\n",
    "config_details = {\n",
    "    \"Name\": \"EmotionChattingMachine\",\n",
    "    \"embeddings\": {\n",
    "        \"vocab_size\": 1000,\n",
    "        \"embed_size\": 128,\n",
    "    },\n",
    "    \"encoder\": {\n",
    "        \"num_layers\": 2,\n",
    "        \"num_units\": 256,\n",
    "        \"cell_type\": \"LSTM\",\n",
    "        \"bidirectional\": True,\n",
    "    },\n",
    "    \"decoder\": {\n",
    "        \"num_layers\": 2,\n",
    "        \"num_units\": 256,\n",
    "        \"cell_type\": \"LSTM\",\n",
    "        \"state_pass\": True,\n",
    "        \"wrapper\": \"ECM\",\n",
    "        \"attn_num_units\": 128,\n",
    "        \"num_emotions\": 4,\n",
    "        \"emo_cat_units\": 32,\n",
    "        \"emo_int_units\": 64,\n",
    "    },\n",
    "    \"inference\": {\n",
    "        \"infer_batch_size\": 15,\n",
    "        \"type\": \"beam_search\",\n",
    "        \"beam_size\": 5,\n",
    "        \"max_length\": 20,\n",
    "        \"infer_source_file\": \"./example/dev_source.txt\",\n",
    "        \"infer_category_file\": \"./example/dev_category.txt\",\n",
    "        \"infer_source_max_length\": 25,\n",
    "        \"output_path\": \"./ECM_prediction.txt\",\n",
    "        \"choice_path\": \"./choice_pred.txt\",\n",
    "    },\n",
    "    \"training\": {\n",
    "        \"l2_regularize\": None,\n",
    "        \"logdir\": \"./log_ECM/\",\n",
    "        \"restore_from\": \"./log_ECM/\",\n",
    "        \"learning_rate\": 1e-3,\n",
    "        \"batch_size\": 64,\n",
    "        \"gpu_fraction\": 0.05,\n",
    "        \"max_checkpoints\": 10000,\n",
    "        \"train_steps\": 5000,\n",
    "        \"print_every\": 20,\n",
    "        \"checkpoint_every\": 1000,\n",
    "        \"train_source_file\": \"./example/train_source.txt\",\n",
    "        \"train_target_file\": \"./example/train_target.txt\",\n",
    "        \"train_choice_file\": \"./example/train_choice.txt\",\n",
    "        \"train_category_file\": \"./example/train_category.txt\",\n",
    "        \"dev_source_file\": \"./example/dev_source.txt\",\n",
    "        \"dev_target_file\": \"./example/dev_target.txt\",\n",
    "        \"dev_choice_file\": \"./example/dev_choice.txt\",\n",
    "        \"dev_category_file\": \"./example/dev_category.txt\",\n",
    "        \"source_max_length\": 25,\n",
    "        \"target_max_length\": 25,\n",
    "        \"loss_fig\": \"./ECM_training_loss_over_time\",\n",
    "        \"perplexity_fig\": \"./ECM_perplexity_over_time\",\n",
    "    }\n",
    "}\n",
    "\n",
    "with open('./configs/config_ECM.yaml', \"w\") as f:\n",
    "    yaml.dump({\"configuration\": config_details}, f, default_flow_style=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "with open('./configs/config_ECM.yaml') as f:\n",
    "    # use safe_load instead load\n",
    "    config = yaml.safe_load(f)\n",
    "\n",
    "config = config[\"configuration\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Construct model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initializing embeddings ...\n",
      "\tDone.\n",
      "Building model architecture ...\n",
      "\tDone.\n"
     ]
    }
   ],
   "source": [
    "# loading configurations\n",
    "name = config[\"Name\"]\n",
    "\n",
    "# Construct or load embeddings\n",
    "print(\"Initializing embeddings ...\")\n",
    "vocab_size = config[\"embeddings\"][\"vocab_size\"]\n",
    "embed_size = config[\"embeddings\"][\"embed_size\"]\n",
    "embeddings = init_embeddings(vocab_size, embed_size, name=name)\n",
    "print(\"\\tDone.\")\n",
    "\n",
    "# Build the model and compute losses\n",
    "source_ids = tf.placeholder(tf.int32, [None, None], name=\"source\")\n",
    "target_ids = tf.placeholder(tf.int32, [None, None], name=\"target\")\n",
    "sequence_mask = tf.placeholder(tf.bool, [None, None], name=\"mask\")\n",
    "choice_qs = tf.placeholder(tf.float32, [None, None], name=\"choice\")\n",
    "emo_cat = tf.placeholder(tf.int32, [None], name=\"emotion_category\")\n",
    "\n",
    "(enc_num_layers, enc_num_units, enc_cell_type, enc_bidir,\n",
    " dec_num_layers, dec_num_units, dec_cell_type, state_pass,\n",
    " num_emo, emo_cat_units, emo_int_units,\n",
    " infer_batch_size, beam_size, max_iter,\n",
    " attn_num_units, l2_regularize) = get_ECM_config(config)\n",
    "\n",
    "print(\"Building model architecture ...\")\n",
    "\n",
    "CE, loss, train_outs, infer_outputs = compute_ECM_loss(\n",
    "    source_ids, target_ids, sequence_mask, choice_qs, embeddings,\n",
    "    enc_num_layers, enc_num_units, enc_cell_type, enc_bidir,\n",
    "    dec_num_layers, dec_num_units, dec_cell_type, state_pass,\n",
    "    num_emo, emo_cat, emo_cat_units, emo_int_units, infer_batch_size,\n",
    "    beam_size, max_iter, attn_num_units, l2_regularize, name)\n",
    "print(\"\\tDone.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/aaronlai/env3/lib/python3.5/site-packages/tensorflow/python/ops/gradients_impl.py:96: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trying to restore saved checkpoints from ./log_ECM/ ... No checkpoint found.\n"
     ]
    }
   ],
   "source": [
    "# Preparing for training\n",
    "(logdir, restore_from, learning_rate, gpu_fraction, max_checkpoints,\n",
    " train_steps, batch_size, print_every, checkpoint_every, s_filename,\n",
    " t_filename, q_filename, c_filename, s_max_leng, t_max_leng,\n",
    " dev_s_filename, dev_t_filename, dev_q_filename, dev_c_filename,\n",
    " loss_fig, perp_fig) = get_ECM_training_config(config)\n",
    "\n",
    "# Even if we restored the model, we will treat it as new training\n",
    "# if the trained model is written into an arbitrary location.\n",
    "is_overwritten_training = logdir != restore_from\n",
    "\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,\n",
    "                                   epsilon=1e-4)\n",
    "trainable = tf.trainable_variables()\n",
    "optim = optimizer.minimize(loss, var_list=trainable)\n",
    "\n",
    "# Set up session\n",
    "gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)\n",
    "sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,\n",
    "                                        gpu_options=gpu_options))\n",
    "init = tf.global_variables_initializer()\n",
    "sess.run(init)\n",
    "\n",
    "# Saver for storing checkpoints of the model.\n",
    "saver = tf.train.Saver(var_list=tf.trainable_variables(),\n",
    "                       max_to_keep=max_checkpoints)\n",
    "\n",
    "try:\n",
    "    saved_global_step = load(saver, sess, restore_from)\n",
    "    if is_overwritten_training or saved_global_step is None:\n",
    "        # The first training step will be saved_global_step + 1,\n",
    "        # therefore we put -1 here for new or overwritten trainings.\n",
    "        saved_global_step = -1\n",
    "\n",
    "except Exception:\n",
    "    print(\"Something went wrong while restoring checkpoint. \"\n",
    "          \"Training is terminated to avoid the overwriting.\")\n",
    "    raise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data ...\n",
      "\tDone.\n"
     ]
    }
   ],
   "source": [
    "# Load data\n",
    "print(\"Loading data ...\")\n",
    "\n",
    "# id_0, id_1, id_2 preserved for SOS, EOS, constant zero padding\n",
    "embed_shift = 3\n",
    "\n",
    "source_data = loadfile(s_filename, is_source=True,\n",
    "                       max_length=s_max_leng) + embed_shift\n",
    "target_data = loadfile(t_filename, is_source=False,\n",
    "                       max_length=t_max_leng) + embed_shift\n",
    "\n",
    "choice_data = loadfile(q_filename, is_source=False, max_length=t_max_leng)\n",
    "choice_data[choice_data < 0] = 0\n",
    "choice_data = choice_data.astype(np.float32)\n",
    "\n",
    "category_data = pd.read_csv(\n",
    "    c_filename, header=None, index_col=None, dtype=int)[0].values\n",
    "\n",
    "masks = (target_data >= embed_shift)\n",
    "masks = np.append(np.ones([len(masks), 1], dtype=bool), masks, axis=1)\n",
    "masks = masks[:, :-1]\n",
    "\n",
    "n_data = len(source_data)\n",
    "\n",
    "dev_source_data = None\n",
    "if dev_s_filename is not None:\n",
    "    dev_source_data = loadfile(dev_s_filename, is_source=True,\n",
    "                               max_length=s_max_leng) + embed_shift\n",
    "    dev_target_data = loadfile(dev_t_filename, is_source=False,\n",
    "                               max_length=t_max_leng) + embed_shift\n",
    "\n",
    "    dev_choice_data = loadfile(dev_q_filename, is_source=False,\n",
    "                               max_length=t_max_leng)\n",
    "    dev_choice_data[dev_choice_data < 0] = 0\n",
    "    dev_choice_data = dev_choice_data.astype(np.float32)\n",
    "\n",
    "    dev_category_data = pd.read_csv(\n",
    "        dev_c_filename, header=None, index_col=None, dtype=int)[0].values\n",
    "\n",
    "    dev_masks = (dev_target_data >= embed_shift)\n",
    "    dev_masks = np.append(\n",
    "        np.ones([len(dev_masks), 1], dtype=bool), dev_masks, axis=1)\n",
    "    dev_masks = dev_masks[:, :-1]\n",
    "\n",
    "print(\"\\tDone.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start training ...\n",
      "step 0, loss = 7.691133, perp: 1002.357, dev_prep: 1002.630, (0.678 sec/step)\n",
      "Storing checkpoint to ./log_ECM/ ... Done.\n",
      "step 20, loss = 7.485613, perp: 953.184, dev_prep: 954.556, (0.182 sec/step)\n",
      "step 40, loss = 7.345282, perp: 840.894, dev_prep: 815.441, (0.191 sec/step)\n",
      "step 60, loss = 7.058381, perp: 721.055, dev_prep: 715.302, (0.186 sec/step)\n",
      "step 80, loss = 6.743235, perp: 577.174, dev_prep: 588.318, (0.184 sec/step)\n",
      "step 100, loss = 6.396844, perp: 455.955, dev_prep: 468.569, (0.191 sec/step)\n",
      "step 120, loss = 6.246303, perp: 413.593, dev_prep: 407.883, (0.193 sec/step)\n",
      "step 140, loss = 6.178908, perp: 394.116, dev_prep: 400.349, (0.190 sec/step)\n",
      "step 160, loss = 6.179094, perp: 398.020, dev_prep: 403.934, (0.186 sec/step)\n",
      "step 180, loss = 6.168450, perp: 381.776, dev_prep: 392.989, (0.182 sec/step)\n",
      "step 200, loss = 6.157249, perp: 389.813, dev_prep: 387.887, (0.191 sec/step)\n",
      "step 220, loss = 6.105398, perp: 358.223, dev_prep: 399.134, (0.184 sec/step)\n",
      "step 240, loss = 6.263700, perp: 427.244, dev_prep: 383.074, (0.195 sec/step)\n",
      "step 260, loss = 6.061069, perp: 337.857, dev_prep: 407.689, (0.179 sec/step)\n",
      "step 280, loss = 6.144888, perp: 381.682, dev_prep: 360.577, (0.181 sec/step)\n",
      "step 300, loss = 6.065662, perp: 365.256, dev_prep: 378.623, (0.190 sec/step)\n",
      "step 320, loss = 6.043445, perp: 356.972, dev_prep: 362.546, (0.180 sec/step)\n",
      "step 340, loss = 6.077257, perp: 359.708, dev_prep: 350.801, (0.186 sec/step)\n",
      "step 360, loss = 5.995629, perp: 333.147, dev_prep: 379.990, (0.194 sec/step)\n",
      "step 380, loss = 6.095203, perp: 370.319, dev_prep: 412.752, (0.175 sec/step)\n",
      "step 400, loss = 6.098328, perp: 364.636, dev_prep: 364.223, (0.182 sec/step)\n",
      "step 420, loss = 6.056013, perp: 351.322, dev_prep: 360.352, (0.188 sec/step)\n",
      "step 440, loss = 6.178652, perp: 389.235, dev_prep: 353.310, (0.185 sec/step)\n",
      "step 460, loss = 6.064288, perp: 361.414, dev_prep: 375.297, (0.186 sec/step)\n",
      "step 480, loss = 6.064642, perp: 356.037, dev_prep: 380.801, (0.183 sec/step)\n",
      "step 500, loss = 6.098613, perp: 361.496, dev_prep: 372.791, (0.185 sec/step)\n",
      "step 520, loss = 6.127523, perp: 376.625, dev_prep: 385.062, (0.184 sec/step)\n",
      "step 540, loss = 6.047297, perp: 344.658, dev_prep: 317.565, (0.187 sec/step)\n",
      "step 560, loss = 6.079378, perp: 364.850, dev_prep: 365.298, (0.184 sec/step)\n",
      "step 580, loss = 6.147167, perp: 387.808, dev_prep: 363.081, (0.183 sec/step)\n",
      "step 600, loss = 6.107033, perp: 373.649, dev_prep: 346.314, (0.189 sec/step)\n",
      "step 620, loss = 6.074249, perp: 366.439, dev_prep: 377.032, (0.188 sec/step)\n",
      "step 640, loss = 5.990425, perp: 338.603, dev_prep: 343.356, (0.179 sec/step)\n",
      "step 660, loss = 6.012284, perp: 335.245, dev_prep: 374.811, (0.187 sec/step)\n",
      "step 680, loss = 6.021513, perp: 314.055, dev_prep: 315.936, (0.179 sec/step)\n",
      "step 700, loss = 5.934940, perp: 308.768, dev_prep: 302.294, (0.183 sec/step)\n",
      "step 720, loss = 5.761100, perp: 257.493, dev_prep: 278.156, (0.192 sec/step)\n",
      "step 740, loss = 5.724792, perp: 249.293, dev_prep: 249.417, (0.187 sec/step)\n",
      "step 760, loss = 5.706587, perp: 263.905, dev_prep: 269.997, (0.186 sec/step)\n",
      "step 780, loss = 5.779139, perp: 261.493, dev_prep: 274.679, (0.190 sec/step)\n",
      "step 800, loss = 5.667672, perp: 248.761, dev_prep: 270.300, (0.189 sec/step)\n",
      "step 820, loss = 5.629139, perp: 220.769, dev_prep: 234.374, (0.182 sec/step)\n",
      "step 840, loss = 5.613652, perp: 236.594, dev_prep: 243.357, (0.191 sec/step)\n",
      "step 860, loss = 5.552343, perp: 211.699, dev_prep: 233.905, (0.181 sec/step)\n",
      "step 880, loss = 5.532242, perp: 210.051, dev_prep: 227.245, (0.189 sec/step)\n",
      "step 900, loss = 5.563748, perp: 214.746, dev_prep: 217.878, (0.187 sec/step)\n",
      "step 920, loss = 5.499141, perp: 211.956, dev_prep: 222.475, (0.183 sec/step)\n",
      "step 940, loss = 5.628675, perp: 219.841, dev_prep: 227.624, (0.191 sec/step)\n",
      "step 960, loss = 5.509669, perp: 202.866, dev_prep: 211.244, (0.189 sec/step)\n",
      "step 980, loss = 5.468205, perp: 200.421, dev_prep: 203.053, (0.182 sec/step)\n",
      "step 1000, loss = 5.443477, perp: 186.955, dev_prep: 190.339, (0.187 sec/step)\n",
      "Storing checkpoint to ./log_ECM/ ... Done.\n",
      "step 1020, loss = 5.421318, perp: 193.446, dev_prep: 194.809, (0.190 sec/step)\n",
      "step 1040, loss = 5.441096, perp: 188.522, dev_prep: 184.666, (0.189 sec/step)\n",
      "step 1060, loss = 5.348361, perp: 177.523, dev_prep: 184.724, (0.189 sec/step)\n",
      "step 1080, loss = 5.346297, perp: 173.700, dev_prep: 184.256, (0.182 sec/step)\n",
      "step 1100, loss = 5.339422, perp: 168.505, dev_prep: 170.122, (0.165 sec/step)\n",
      "step 1120, loss = 5.317061, perp: 170.274, dev_prep: 176.452, (0.187 sec/step)\n",
      "step 1140, loss = 5.289461, perp: 171.066, dev_prep: 174.794, (0.189 sec/step)\n",
      "step 1160, loss = 5.260417, perp: 163.334, dev_prep: 165.963, (0.186 sec/step)\n",
      "step 1180, loss = 5.293383, perp: 165.959, dev_prep: 165.058, (0.183 sec/step)\n",
      "step 1200, loss = 5.319636, perp: 171.486, dev_prep: 175.763, (0.194 sec/step)\n",
      "step 1220, loss = 5.273771, perp: 162.401, dev_prep: 159.233, (0.183 sec/step)\n",
      "step 1240, loss = 5.255383, perp: 156.793, dev_prep: 156.052, (0.186 sec/step)\n",
      "step 1260, loss = 5.324623, perp: 150.741, dev_prep: 156.712, (0.189 sec/step)\n",
      "step 1280, loss = 5.196091, perp: 158.776, dev_prep: 157.875, (0.182 sec/step)\n",
      "step 1300, loss = 5.154405, perp: 144.983, dev_prep: 146.116, (0.185 sec/step)\n",
      "step 1320, loss = 5.221665, perp: 152.431, dev_prep: 146.706, (0.183 sec/step)\n",
      "step 1340, loss = 5.167669, perp: 146.781, dev_prep: 152.595, (0.187 sec/step)\n",
      "step 1360, loss = 5.147549, perp: 146.348, dev_prep: 147.178, (0.181 sec/step)\n",
      "step 1380, loss = 5.149224, perp: 144.063, dev_prep: 148.795, (0.182 sec/step)\n",
      "step 1400, loss = 5.161775, perp: 147.036, dev_prep: 149.862, (0.185 sec/step)\n",
      "step 1420, loss = 5.158008, perp: 145.615, dev_prep: 146.720, (0.193 sec/step)\n",
      "step 1440, loss = 5.095107, perp: 138.453, dev_prep: 147.558, (0.189 sec/step)\n",
      "step 1460, loss = 5.110743, perp: 146.595, dev_prep: 155.549, (0.190 sec/step)\n",
      "step 1480, loss = 5.115709, perp: 134.465, dev_prep: 142.322, (0.183 sec/step)\n",
      "step 1500, loss = 5.128223, perp: 137.091, dev_prep: 138.902, (0.185 sec/step)\n",
      "step 1520, loss = 5.125191, perp: 134.628, dev_prep: 142.874, (0.187 sec/step)\n",
      "step 1540, loss = 5.034696, perp: 138.651, dev_prep: 142.585, (0.186 sec/step)\n",
      "step 1560, loss = 5.052518, perp: 131.320, dev_prep: 136.077, (0.190 sec/step)\n",
      "step 1580, loss = 5.049092, perp: 128.070, dev_prep: 136.688, (0.183 sec/step)\n",
      "step 1600, loss = 5.023190, perp: 129.598, dev_prep: 127.784, (0.186 sec/step)\n",
      "step 1620, loss = 5.011702, perp: 123.527, dev_prep: 134.418, (0.185 sec/step)\n",
      "step 1640, loss = 4.972788, perp: 119.564, dev_prep: 118.792, (0.189 sec/step)\n",
      "step 1660, loss = 4.944819, perp: 114.797, dev_prep: 124.753, (0.182 sec/step)\n",
      "step 1680, loss = 4.986906, perp: 118.391, dev_prep: 122.841, (0.182 sec/step)\n",
      "step 1700, loss = 4.924162, perp: 115.334, dev_prep: 120.212, (0.179 sec/step)\n",
      "step 1720, loss = 5.000564, perp: 120.495, dev_prep: 123.440, (0.184 sec/step)\n",
      "step 1740, loss = 4.951468, perp: 113.611, dev_prep: 118.430, (0.162 sec/step)\n",
      "step 1760, loss = 4.926950, perp: 113.751, dev_prep: 119.284, (0.187 sec/step)\n",
      "step 1780, loss = 4.926433, perp: 119.986, dev_prep: 123.301, (0.189 sec/step)\n",
      "step 1800, loss = 4.906546, perp: 113.137, dev_prep: 121.087, (0.182 sec/step)\n",
      "step 1820, loss = 4.942041, perp: 117.859, dev_prep: 125.200, (0.185 sec/step)\n",
      "step 1840, loss = 4.880840, perp: 112.094, dev_prep: 122.996, (0.182 sec/step)\n",
      "step 1860, loss = 4.900367, perp: 114.449, dev_prep: 119.402, (0.163 sec/step)\n",
      "step 1880, loss = 4.897328, perp: 114.741, dev_prep: 120.160, (0.182 sec/step)\n",
      "step 1900, loss = 4.873522, perp: 109.517, dev_prep: 119.552, (0.181 sec/step)\n",
      "step 1920, loss = 4.894629, perp: 109.931, dev_prep: 116.454, (0.190 sec/step)\n",
      "step 1940, loss = 4.860853, perp: 109.179, dev_prep: 118.305, (0.188 sec/step)\n",
      "step 1960, loss = 4.936251, perp: 115.919, dev_prep: 121.279, (0.188 sec/step)\n",
      "step 1980, loss = 4.908811, perp: 110.292, dev_prep: 113.885, (0.182 sec/step)\n",
      "step 2000, loss = 4.841141, perp: 106.377, dev_prep: 122.613, (0.188 sec/step)\n",
      "Storing checkpoint to ./log_ECM/ ... Done.\n",
      "step 2020, loss = 4.920382, perp: 115.862, dev_prep: 120.383, (0.185 sec/step)\n",
      "step 2040, loss = 4.876149, perp: 112.446, dev_prep: 123.594, (0.194 sec/step)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 2060, loss = 4.927907, perp: 110.146, dev_prep: 118.749, (0.187 sec/step)\n",
      "step 2080, loss = 4.896173, perp: 110.470, dev_prep: 115.716, (0.194 sec/step)\n",
      "step 2100, loss = 4.868523, perp: 109.953, dev_prep: 114.307, (0.182 sec/step)\n",
      "step 2120, loss = 4.875549, perp: 109.977, dev_prep: 110.834, (0.189 sec/step)\n",
      "step 2140, loss = 4.881153, perp: 108.387, dev_prep: 110.195, (0.191 sec/step)\n",
      "step 2160, loss = 4.828449, perp: 104.490, dev_prep: 109.685, (0.188 sec/step)\n",
      "step 2180, loss = 4.831602, perp: 102.629, dev_prep: 107.534, (0.185 sec/step)\n",
      "step 2200, loss = 4.820033, perp: 103.544, dev_prep: 106.777, (0.192 sec/step)\n",
      "step 2220, loss = 4.782588, perp: 98.046, dev_prep: 101.781, (0.183 sec/step)\n",
      "step 2240, loss = 4.758969, perp: 95.714, dev_prep: 102.943, (0.187 sec/step)\n",
      "step 2260, loss = 4.794845, perp: 98.235, dev_prep: 103.815, (0.183 sec/step)\n",
      "step 2280, loss = 4.729478, perp: 95.127, dev_prep: 101.770, (0.191 sec/step)\n",
      "step 2300, loss = 4.740843, perp: 95.360, dev_prep: 103.751, (0.185 sec/step)\n",
      "step 2320, loss = 4.729052, perp: 94.886, dev_prep: 107.215, (0.190 sec/step)\n",
      "step 2340, loss = 4.790201, perp: 102.502, dev_prep: 102.169, (0.187 sec/step)\n",
      "step 2360, loss = 4.776232, perp: 98.504, dev_prep: 102.414, (0.183 sec/step)\n",
      "step 2380, loss = 4.771354, perp: 94.238, dev_prep: 99.767, (0.179 sec/step)\n",
      "step 2400, loss = 4.743855, perp: 94.906, dev_prep: 98.412, (0.187 sec/step)\n",
      "step 2420, loss = 4.760398, perp: 95.788, dev_prep: 98.236, (0.181 sec/step)\n",
      "step 2440, loss = 4.716070, perp: 95.361, dev_prep: 98.004, (0.179 sec/step)\n",
      "step 2460, loss = 4.733517, perp: 100.446, dev_prep: 100.645, (0.185 sec/step)\n",
      "step 2480, loss = 4.659625, perp: 89.647, dev_prep: 96.115, (0.187 sec/step)\n",
      "step 2500, loss = 4.702543, perp: 91.974, dev_prep: 93.192, (0.187 sec/step)\n",
      "step 2520, loss = 4.656796, perp: 89.447, dev_prep: 93.608, (0.182 sec/step)\n",
      "step 2540, loss = 4.693813, perp: 90.302, dev_prep: 89.505, (0.186 sec/step)\n",
      "step 2560, loss = 4.705162, perp: 94.846, dev_prep: 95.162, (0.083 sec/step)\n",
      "step 2580, loss = 4.632351, perp: 86.151, dev_prep: 95.031, (0.080 sec/step)\n",
      "step 2600, loss = 4.656510, perp: 88.829, dev_prep: 93.670, (0.078 sec/step)\n",
      "step 2620, loss = 4.658534, perp: 88.956, dev_prep: 89.515, (0.079 sec/step)\n",
      "step 2640, loss = 4.622330, perp: 83.819, dev_prep: 93.166, (0.082 sec/step)\n",
      "step 2660, loss = 4.585625, perp: 83.894, dev_prep: 89.330, (0.079 sec/step)\n",
      "step 2680, loss = 4.644842, perp: 86.917, dev_prep: 94.183, (0.083 sec/step)\n",
      "step 2700, loss = 4.633335, perp: 88.337, dev_prep: 92.091, (0.079 sec/step)\n",
      "step 2720, loss = 4.699011, perp: 89.642, dev_prep: 91.290, (0.079 sec/step)\n",
      "step 2740, loss = 4.645009, perp: 85.878, dev_prep: 89.038, (0.080 sec/step)\n",
      "step 2760, loss = 4.749645, perp: 95.479, dev_prep: 94.494, (0.079 sec/step)\n",
      "step 2780, loss = 4.607339, perp: 84.594, dev_prep: 91.236, (0.084 sec/step)\n",
      "step 2800, loss = 4.620332, perp: 82.940, dev_prep: 90.569, (0.079 sec/step)\n",
      "step 2820, loss = 4.597293, perp: 81.907, dev_prep: 87.513, (0.079 sec/step)\n",
      "step 2840, loss = 4.617379, perp: 88.268, dev_prep: 89.192, (0.079 sec/step)\n",
      "step 2860, loss = 4.600885, perp: 83.437, dev_prep: 85.890, (0.081 sec/step)\n",
      "step 2880, loss = 4.573946, perp: 81.296, dev_prep: 88.377, (0.085 sec/step)\n",
      "step 2900, loss = 4.592528, perp: 83.249, dev_prep: 85.875, (0.167 sec/step)\n",
      "step 2920, loss = 4.604003, perp: 83.089, dev_prep: 82.393, (0.080 sec/step)\n",
      "step 2940, loss = 4.582617, perp: 81.753, dev_prep: 86.923, (0.080 sec/step)\n",
      "step 2960, loss = 4.608225, perp: 82.776, dev_prep: 86.733, (0.081 sec/step)\n",
      "step 2980, loss = 4.578659, perp: 80.656, dev_prep: 86.371, (0.081 sec/step)\n",
      "step 3000, loss = 4.516659, perp: 78.021, dev_prep: 86.935, (0.080 sec/step)\n",
      "Storing checkpoint to ./log_ECM/ ... Done.\n",
      "step 3020, loss = 4.559056, perp: 80.956, dev_prep: 88.171, (0.079 sec/step)\n",
      "step 3040, loss = 4.531827, perp: 77.965, dev_prep: 83.728, (0.080 sec/step)\n",
      "step 3060, loss = 4.578849, perp: 79.982, dev_prep: 84.877, (0.081 sec/step)\n",
      "step 3080, loss = 4.530557, perp: 77.949, dev_prep: 82.308, (0.080 sec/step)\n",
      "step 3100, loss = 4.535232, perp: 78.006, dev_prep: 81.373, (0.081 sec/step)\n",
      "step 3120, loss = 4.482614, perp: 75.954, dev_prep: 79.441, (0.083 sec/step)\n",
      "step 3140, loss = 4.514441, perp: 76.751, dev_prep: 80.105, (0.082 sec/step)\n",
      "step 3160, loss = 4.525198, perp: 76.230, dev_prep: 83.383, (0.079 sec/step)\n",
      "step 3180, loss = 4.498977, perp: 75.702, dev_prep: 81.148, (0.079 sec/step)\n",
      "step 3200, loss = 4.568997, perp: 82.044, dev_prep: 80.345, (0.079 sec/step)\n",
      "step 3220, loss = 4.509623, perp: 76.135, dev_prep: 82.182, (0.078 sec/step)\n",
      "step 3240, loss = 4.502332, perp: 77.687, dev_prep: 80.732, (0.081 sec/step)\n",
      "step 3260, loss = 4.530997, perp: 77.289, dev_prep: 80.352, (0.078 sec/step)\n",
      "step 3280, loss = 4.481592, perp: 74.368, dev_prep: 80.587, (0.080 sec/step)\n",
      "step 3300, loss = 4.543622, perp: 79.445, dev_prep: 81.863, (0.080 sec/step)\n",
      "step 3320, loss = 4.511406, perp: 75.169, dev_prep: 81.047, (0.080 sec/step)\n",
      "step 3340, loss = 4.500355, perp: 77.342, dev_prep: 81.587, (0.080 sec/step)\n",
      "step 3360, loss = 4.573911, perp: 78.836, dev_prep: 80.589, (0.081 sec/step)\n",
      "step 3380, loss = 4.553513, perp: 79.910, dev_prep: 81.086, (0.078 sec/step)\n",
      "step 3400, loss = 4.532192, perp: 75.592, dev_prep: 78.304, (0.082 sec/step)\n",
      "step 3420, loss = 4.500751, perp: 79.848, dev_prep: 83.318, (0.078 sec/step)\n",
      "step 3440, loss = 4.535955, perp: 77.259, dev_prep: 81.985, (0.079 sec/step)\n",
      "step 3460, loss = 4.476738, perp: 73.482, dev_prep: 77.843, (0.079 sec/step)\n",
      "step 3480, loss = 4.511874, perp: 77.115, dev_prep: 77.682, (0.078 sec/step)\n",
      "step 3500, loss = 4.457969, perp: 73.108, dev_prep: 76.463, (0.080 sec/step)\n",
      "step 3520, loss = 4.462436, perp: 73.593, dev_prep: 77.071, (0.083 sec/step)\n",
      "step 3540, loss = 4.527004, perp: 77.215, dev_prep: 75.091, (0.080 sec/step)\n",
      "step 3560, loss = 4.493571, perp: 74.579, dev_prep: 78.790, (0.080 sec/step)\n",
      "step 3580, loss = 4.496834, perp: 73.610, dev_prep: 78.710, (0.081 sec/step)\n",
      "step 3600, loss = 4.504863, perp: 73.854, dev_prep: 81.691, (0.081 sec/step)\n",
      "step 3620, loss = 4.446844, perp: 72.436, dev_prep: 76.554, (0.080 sec/step)\n",
      "step 3640, loss = 4.487943, perp: 73.651, dev_prep: 75.949, (0.079 sec/step)\n",
      "step 3660, loss = 4.420169, perp: 70.448, dev_prep: 74.066, (0.079 sec/step)\n",
      "step 3680, loss = 4.450901, perp: 71.031, dev_prep: 76.457, (0.081 sec/step)\n",
      "step 3700, loss = 4.424866, perp: 71.765, dev_prep: 77.775, (0.079 sec/step)\n",
      "step 3720, loss = 4.425903, perp: 71.046, dev_prep: 76.727, (0.079 sec/step)\n",
      "step 3740, loss = 4.420540, perp: 68.445, dev_prep: 70.922, (0.080 sec/step)\n",
      "step 3760, loss = 4.446115, perp: 71.495, dev_prep: 73.187, (0.081 sec/step)\n",
      "step 3780, loss = 4.413393, perp: 67.853, dev_prep: 70.986, (0.080 sec/step)\n",
      "step 3800, loss = 4.360610, perp: 65.886, dev_prep: 70.382, (0.081 sec/step)\n",
      "step 3820, loss = 4.387936, perp: 66.835, dev_prep: 71.592, (0.080 sec/step)\n",
      "step 3840, loss = 4.380410, perp: 65.714, dev_prep: 72.976, (0.079 sec/step)\n",
      "step 3860, loss = 4.371659, perp: 67.541, dev_prep: 73.697, (0.078 sec/step)\n",
      "step 3880, loss = 4.357558, perp: 65.835, dev_prep: 70.933, (0.079 sec/step)\n",
      "step 3900, loss = 4.329628, perp: 62.264, dev_prep: 65.040, (0.079 sec/step)\n",
      "step 3920, loss = 4.360806, perp: 65.271, dev_prep: 68.088, (0.080 sec/step)\n",
      "step 3940, loss = 4.328600, perp: 65.648, dev_prep: 67.554, (0.081 sec/step)\n",
      "step 3960, loss = 4.287574, perp: 62.409, dev_prep: 68.826, (0.079 sec/step)\n",
      "step 3980, loss = 4.366633, perp: 64.635, dev_prep: 66.600, (0.082 sec/step)\n",
      "step 4000, loss = 4.300029, perp: 62.531, dev_prep: 66.974, (0.078 sec/step)\n",
      "Storing checkpoint to ./log_ECM/ ... Done.\n",
      "step 4020, loss = 4.309150, perp: 61.121, dev_prep: 66.885, (0.080 sec/step)\n",
      "step 4040, loss = 4.277199, perp: 60.446, dev_prep: 67.561, (0.079 sec/step)\n",
      "step 4060, loss = 4.279891, perp: 61.647, dev_prep: 64.620, (0.080 sec/step)\n",
      "step 4080, loss = 4.262224, perp: 60.627, dev_prep: 64.337, (0.082 sec/step)\n",
      "step 4100, loss = 4.264299, perp: 59.753, dev_prep: 63.343, (0.079 sec/step)\n",
      "step 4120, loss = 4.200579, perp: 57.031, dev_prep: 63.125, (0.080 sec/step)\n",
      "step 4140, loss = 4.236899, perp: 57.046, dev_prep: 60.047, (0.080 sec/step)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 4160, loss = 4.248465, perp: 58.336, dev_prep: 61.721, (0.080 sec/step)\n",
      "step 4180, loss = 4.245794, perp: 59.106, dev_prep: 62.360, (0.090 sec/step)\n",
      "step 4200, loss = 4.131330, perp: 54.075, dev_prep: 59.100, (0.080 sec/step)\n",
      "step 4220, loss = 4.205660, perp: 56.970, dev_prep: 63.348, (0.079 sec/step)\n",
      "step 4240, loss = 4.219598, perp: 56.288, dev_prep: 59.757, (0.079 sec/step)\n",
      "step 4260, loss = 4.219422, perp: 59.529, dev_prep: 62.943, (0.080 sec/step)\n",
      "step 4280, loss = 4.178509, perp: 54.653, dev_prep: 58.011, (0.081 sec/step)\n",
      "step 4300, loss = 4.176664, perp: 56.948, dev_prep: 62.351, (0.081 sec/step)\n",
      "step 4320, loss = 4.201800, perp: 55.264, dev_prep: 60.695, (0.079 sec/step)\n",
      "step 4340, loss = 4.201328, perp: 55.921, dev_prep: 61.883, (0.081 sec/step)\n",
      "step 4360, loss = 4.151330, perp: 55.809, dev_prep: 62.090, (0.080 sec/step)\n",
      "step 4380, loss = 4.170990, perp: 53.314, dev_prep: 57.803, (0.080 sec/step)\n",
      "step 4400, loss = 4.191956, perp: 55.255, dev_prep: 55.121, (0.079 sec/step)\n",
      "step 4420, loss = 4.120603, perp: 53.695, dev_prep: 56.106, (0.081 sec/step)\n",
      "step 4440, loss = 4.158020, perp: 54.133, dev_prep: 55.951, (0.081 sec/step)\n",
      "step 4460, loss = 4.143220, perp: 53.169, dev_prep: 56.949, (0.081 sec/step)\n",
      "step 4480, loss = 4.146042, perp: 52.620, dev_prep: 59.701, (0.079 sec/step)\n",
      "step 4500, loss = 4.128221, perp: 51.466, dev_prep: 54.707, (0.080 sec/step)\n",
      "step 4520, loss = 4.085796, perp: 50.922, dev_prep: 53.523, (0.080 sec/step)\n",
      "step 4540, loss = 4.108714, perp: 52.554, dev_prep: 54.882, (0.079 sec/step)\n",
      "step 4560, loss = 4.064867, perp: 49.295, dev_prep: 54.089, (0.079 sec/step)\n",
      "step 4580, loss = 4.017982, perp: 48.182, dev_prep: 54.323, (0.081 sec/step)\n",
      "step 4600, loss = 4.061761, perp: 48.718, dev_prep: 50.857, (0.079 sec/step)\n",
      "step 4620, loss = 4.038851, perp: 47.804, dev_prep: 50.335, (0.081 sec/step)\n",
      "step 4640, loss = 4.062167, perp: 48.426, dev_prep: 52.050, (0.080 sec/step)\n",
      "step 4660, loss = 4.041481, perp: 48.472, dev_prep: 51.895, (0.078 sec/step)\n",
      "step 4680, loss = 4.036169, perp: 48.464, dev_prep: 53.689, (0.078 sec/step)\n",
      "step 4700, loss = 3.990679, perp: 46.216, dev_prep: 51.495, (0.080 sec/step)\n",
      "step 4720, loss = 4.024677, perp: 48.344, dev_prep: 51.805, (0.078 sec/step)\n",
      "step 4740, loss = 4.020321, perp: 47.545, dev_prep: 51.100, (0.080 sec/step)\n",
      "step 4760, loss = 3.983739, perp: 46.289, dev_prep: 50.346, (0.081 sec/step)\n",
      "step 4780, loss = 3.997794, perp: 45.867, dev_prep: 49.351, (0.080 sec/step)\n",
      "step 4800, loss = 3.973680, perp: 46.130, dev_prep: 49.737, (0.078 sec/step)\n",
      "step 4820, loss = 3.985074, perp: 45.821, dev_prep: 47.821, (0.083 sec/step)\n",
      "step 4840, loss = 3.968639, perp: 45.215, dev_prep: 48.985, (0.080 sec/step)\n",
      "step 4860, loss = 3.926841, perp: 43.686, dev_prep: 49.580, (0.079 sec/step)\n",
      "step 4880, loss = 3.886644, perp: 41.428, dev_prep: 44.427, (0.080 sec/step)\n",
      "step 4900, loss = 3.905177, perp: 44.362, dev_prep: 47.954, (0.081 sec/step)\n",
      "step 4920, loss = 3.948175, perp: 42.487, dev_prep: 45.684, (0.080 sec/step)\n",
      "step 4940, loss = 3.842580, perp: 40.326, dev_prep: 45.371, (0.078 sec/step)\n",
      "step 4960, loss = 3.881011, perp: 40.816, dev_prep: 46.718, (0.080 sec/step)\n",
      "step 4980, loss = 3.919305, perp: 42.109, dev_prep: 46.560, (0.079 sec/step)\n",
      "Storing checkpoint to ./log_ECM/ ... Done.\n"
     ]
    }
   ],
   "source": [
    "last_saved_step = saved_global_step\n",
    "num_steps = saved_global_step + train_steps  # last global step to run (inclusive)\n",
    "losses = []     # training loss at every step\n",
    "steps = []      # global steps at which perplexities were logged\n",
    "perps = []      # training-batch perplexity at each logged step\n",
    "dev_perps = []  # dev-batch perplexity at each logged step\n",
    "\n",
    "# Bind `step` before the try block so the `finally` clause cannot raise\n",
    "# NameError if training is interrupted before the first iteration (or if\n",
    "# train_steps <= 0 and the loop body never runs).\n",
    "step = saved_global_step\n",
    "\n",
    "print(\"Start training ...\")\n",
    "try:\n",
    "    # The stop value is num_steps + 1 because range() excludes its stop:\n",
    "    # this runs exactly `train_steps` steps (saved_global_step + 1 .. num_steps).\n",
    "    for step in range(saved_global_step + 1, num_steps + 1):\n",
    "        start_time = time.time()\n",
    "        # Sample a random training batch (np.random.choice samples with replacement).\n",
    "        rand_indexes = np.random.choice(n_data, batch_size)\n",
    "        source_batch = source_data[rand_indexes]\n",
    "        target_batch = target_data[rand_indexes]\n",
    "        choice_batch = choice_data[rand_indexes]\n",
    "        emotions = category_data[rand_indexes]\n",
    "        mask_batch = masks[rand_indexes]\n",
    "\n",
    "        feed_dict = {\n",
    "            source_ids: source_batch,\n",
    "            target_ids: target_batch,\n",
    "            choice_qs: choice_batch,\n",
    "            emo_cat: emotions,\n",
    "            sequence_mask: mask_batch,\n",
    "        }\n",
    "\n",
    "        # Fetch `loss` and run `optim` (presumably the train op) in one call.\n",
    "        loss_value, _ = sess.run([loss, optim], feed_dict=feed_dict)\n",
    "        losses.append(loss_value)\n",
    "\n",
    "        duration = time.time() - start_time\n",
    "\n",
    "        if step % print_every == 0:\n",
    "            # train perplexity on the batch just fed to sess.run\n",
    "            t_perp = compute_perplexity(sess, CE, mask_batch, feed_dict)\n",
    "            perps.append(t_perp)\n",
    "\n",
    "            # dev perplexity on a freshly sampled dev batch\n",
    "            dev_str = \"\"\n",
    "            if dev_source_data is not None:\n",
    "                dev_inds = np.random.choice(\n",
    "                    len(dev_source_data), batch_size)\n",
    "                # Index the dev masks once and reuse for the feed and the metric.\n",
    "                dev_mask_batch = dev_masks[dev_inds]\n",
    "\n",
    "                dev_feed_dict = {\n",
    "                    source_ids: dev_source_data[dev_inds],\n",
    "                    target_ids: dev_target_data[dev_inds],\n",
    "                    choice_qs: dev_choice_data[dev_inds],\n",
    "                    emo_cat: dev_category_data[dev_inds],\n",
    "                    sequence_mask: dev_mask_batch,\n",
    "                }\n",
    "\n",
    "                dev_perp = compute_perplexity(\n",
    "                    sess, CE, dev_mask_batch, dev_feed_dict)\n",
    "                dev_perps.append(dev_perp)\n",
    "                dev_str = \"dev_perp: {:.3f}, \".format(dev_perp)\n",
    "\n",
    "            steps.append(step)\n",
    "            info = 'step {:d}, loss = {:.6f}, '\n",
    "            info += 'perp: {:.3f}, {}({:.3f} sec/step)'\n",
    "            print(info.format(step, loss_value, t_perp, dev_str, duration))\n",
    "\n",
    "        if step % checkpoint_every == 0:\n",
    "            save(saver, sess, logdir, step)\n",
    "            last_saved_step = step\n",
    "\n",
    "except KeyboardInterrupt:\n",
    "    # Introduce a line break after ^C so save message is on its own line.\n",
    "    print()\n",
    "\n",
    "finally:\n",
    "    # Persist any progress made since the last periodic checkpoint.\n",
    "    if step > last_saved_step:\n",
    "        save(saver, sess, logdir, step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEWCAYAAACdaNcBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl4VdXVx/HvykSYxyAzCCKooAgRAScUUQSrra91qLXWaqnWaq3W1rlqHbC21qqt1lqt2tbaOpUqogxqrQIaZhAQkCAzYQpTINN6/7iHkIQMN8m9ubk3v8/z5MkZ9t17HZ6wsrPPOXubuyMiIoklKdYBiIhI5Cm5i4gkICV3EZEEpOQuIpKAlNxFRBKQkruISAJSchcpxczSzczNrFsl52ea2bfrOy6RmlJylwbPzHaX+io2s7xS+5dV89kxZraivmIVaShSYh2ASHXcvcWBbTPLBq5296mxi0ik4VPPXeKemTU1s9+b2QYzW2tmj5hZqpm1B94Aepfq6bc3s5PMbJaZ7TCz9Wb2WzOrcUfHzJLN7F4z+8rMNpnZc2bWMjjX3Mz+YWbbgnZmmVnb4Nz3zSzbzHaZ2Zdm9s3I/ouIKLlLYrgXOBYYCAwBRgI/c/etwDeAL929RfC1FSgAfgS0B04BvgZcXYt2fwBcFNTRF+gIPBqcu5rQX8ZdgQ5Be/lBgn8EGOXuLYGTgUW1aFukSkrukgguA37h7lvcfRNwP3B5ZYXd/VN3/8zdi9x9JfAscFot233E3Ve7+07gDuAyMzNCv0AygD7uXhi0t6fUZweYWbq7r3f3JbVoW6RKSu4S14JE2glYXerwakI95so+c7SZvRMMpewE7ibUu66pLhW02xRoB/wZ+BB4NRgqetDMkt19O6FfCjcAG81sopkdUYu2Raqk5C5xzUPTmm4EepY63ANYd6BIBR/7EzCHUK+6FXAfYLVofn0F7eYB29x9v7vf7e79gVOBbwKXBDG/7e6jCP1y+Ap4qhZti1RJyV0SwcvAL4KbpR0JDY/8NTi3CehoZi1KlW8J5Lr7bjM7Bvh+Hdr9qZn1CG6k3g/83d3dzM4M/kJIAnYChUCxmXU1s3Fm1gzYD+wGimvZvkillNwlEdwNfA4sBuYBHwO/Cs7NByYCq4OnVtoBPwGuNrPdwO+BV2rZ7lPA68AnwEpgG3BTcK4r8G9gF6EbppOCdpKBWwn9tbEVOIHQzVaRiDIt1iEiknjUcxcRSUBK7iIiCUjJXUQkASm5i4gkoJhNHNahQwfv1atXrJoXEYlLs2fP3uLuGdWVi1ly79WrF1lZWbFqXkQkLpnZ6upLaVhGRCQhKbmLiCQgJXcRkQSk5C4ikoCU3EVEEpCSu4hIAlJyFxFJQHGX3Jdt3MWv313Gtj35sQ5FRKTBirvkvmrLbp58fwUbc/fFOhQRkQYr7pJ7y/RUAHbtK4hxJCIiDVfcJfdWQXLfua8wxpGIiDRccZfcW6aHpsNRz11EpHJxl9xbNT0wLKOeu4hIZeIuuR/oua/YvDvGkYiINFxxl9xTk0MhvzQzrFkvRUQapWqTu5n1M7N5pb52mtmN5cqMNLPcUmXujl7IIiJSnWoX63D3ZcAgADNLBtYBb1RQ9CN3Pzey4VXty5zd9M5oUZ9NiojEhZoOy4wCVrp7gxgTueL5T2MdgohIg1TT5H4J8HIl54ab2Xwze8fMjqmogJmNN7MsM8vKycmpYdMHPXbxIADWbMurdR0iIoks7ORuZmnAecC/Kjg9B+jp7scBTwBvVlSHuz/j7pnunpmRUe36rpUaM6ATAF8f1KXWdYiIJLKa9NzPAea4+6byJ9x9p7vvDrYnAalm1iFCMR4iPTWZts1SaZEes/W9RUQatJok90upZEjGzDqZmQXbQ4N6t9Y9vMpt31vAv+etj2YTIiJxK6yur5k1B0YDPyh17BoAd38auBC41swKgTzgEnf3
yIdblt5SFRGpWFjJ3d33AO3LHXu61PaTwJORDa1qx/dow8K1ufXZpIhI3Ii7N1QPOLpzK1oH88yIiEhZcZvc01KS2Lonn+lLD7m/KyLS6MVtcv/Hp2sA+N5fsmIciYhIwxO3yT2voCjWIYiINFhxm9xvGn1krEMQEWmw4ja5/+j0IwBo2UQvMomIlBe3mTEpyTiuextaNEmOdSgiIg1O3CZ3gPlrdgCQX1hMWkr9/xHyymdf0b1dM0b0idpMCyIitRLXyf2AfYVFMUnuP39tIQDZE8bVe9siIlWJ2zH30goKi2MdgohIgxLXyf2XXx8AaLFsEZHy4jq53/XmIgAufmZmjCMREWlY4jq5XzC4a6xDEBFpkOI6ud84KvQi04mHt4txJCIiDUtcJ/ce7ZsBMGvVthhHIiLSsMR1chcRkYopuYuIJKBqk7uZ9TOzeaW+dprZjeXKmJk9bmYrzGyBmQ2OXsgiIlKdapO7uy9z90HuPggYAuwF3ihX7Bygb/A1Hngq0oFWJrNn2wNx1leTIiINXk2HZUYBK919dbnj5wMveshMoI2ZdY5IhNXo1rYpADu1WLaISImaJvdLgJcrON4VWFNqf21wrAwzG29mWWaWlZOTU8OmK3bqkRkAbN+TH5H6REQSQdjJ3czSgPOAf9W2MXd/xt0z3T0zIyOjttWU0So9tEh2bl5BROoTEUkENem5nwPMcfeKVqReB3Qvtd8tOBZ1zdJC87nvzdeyeyIiB9QkuV9KxUMyABOB7wRPzQwDct19Q52jC0OzYCWmvAKNuYuIHBBWcjez5sBo4PVSx64xs2uC3UnAl8AK4E/ADyMcZ6UO9NzfmLu+vpoUEWnwwlqsw933AO3LHXu61LYD10U2tPB0bNkEgCmfb4xF8yIiDVLcr8TUplkaGS2bcEpfLXUnInJAQkw/4A478zTmLiJyQEIk9y279zN1SUUP8YiINE4JkdxFRKSshEjulw/rSZtmqbEOQ0SkwUiI5N6mWSq5eQUUF2vyMBERSJDk3rppKu6wa79uqoqIQAIld4DcvZpfRkQEEiy579yn5C4iAgmS3FsGM0Nu2rkvxpGIiDQMCZHck5MMgKteyIpxJCIiDUNCJPeUZIt1CCIiDUpCJPeBXVvHOgQRkQYlIZJ7anJCXIaISMQoK4qIJCAldxGRBBTuSkxtzOxVM1tqZkvMbHi58yPNLNfM5gVfd0cn3MqN6NOeE3q1re9mRUQapHAX6/gdMNndLzSzNKBZBWU+cvdzIxdazcz9agd5BVokW0QEwkjuZtYaOBX4LoC75wP50Q2r5pTYRUQOCmdY5nAgB3jezOaa2bPBgtnlDTez+Wb2jpkdU1FFZjbezLLMLCsnJ6cucYuISBXCSe4pwGDgKXc/HtgD3FquzBygp7sfBzwBvFlRRe7+jLtnuntmRkZGHcI+1PhTe5OeqvvDIiIQXnJfC6x191nB/quEkn0Jd9/p7ruD7UlAqpnV64rVaclJ5BcW12eTIiINVrXJ3d03AmvMrF9waBTweekyZtbJzCzYHhrUuzXCsVapoKiYYod9GnsXEQn7Offrgb+Z2QJgEPCgmV1jZtcE5y8EFpnZfOBx4BJ3r9dlkf743y8BeP7j7PpsVkSkQQrrUUh3nwdkljv8dKnzTwJPRjCuGuvQoglbdu/n4clLuXZkn1iGIiIScwlzB/Kuc4+KdQgiIg1GwiT3VsGCHSIikkDJvVlacqxDEBFpMBImuZ/Qq12sQxARaTASJrknJWk1JhGRAxImuZc2e/V2AIqLnYKigy82bdm9n81aRFtEGoGETO7/99QnrNm2lxv+MZe+d7xTcjzz/qkMfXBaDCMTEakf4U75G3dO+dX7lZ7bviefguJiOrZMr8eIRETqT0L23Ms778n/sb/w4LQEJz88naEPTOPyP8/iDx+sKFP20feW8cS05fUdoohIRCVUcp90wykVHl+wNpeH31lWsr8nP5ToP1q+hV9NPnh8Y+4+Hp++
gt9M+aLk2IyVW1mxeXeV7RYWacIyEWlYEmpY5ugurSo999zHqyo9997ijfQ9rCWn//qDkmOFRcW8OGM1970VmiNtQNdWvHX9wV8eO/YeXK9kQ+4+ureraHEqEZHYSKjkXlvjX5p9yLG7/r2Ylz/9qmR/0bqdXPTHGfx8TH8+37CTu95cVHJux94CmjfJp13zNCbOX0/rpqmcfEQHkvV4pojEiNXz5I0lMjMzPSsrK+L19rr17YjXGa4Lh3Tj1dlrS/azJ4yLWSwikpjMbLa7l5/I8RAJNeYOcMXwnjFru3RiFxGJpYRL7veeP4DlD5wT6zBERGIq4ZI7QGpyEtecFvs53T9ansOufQWxDkNEGqGETO4AbZvFfgrgy//8KT95ZV6swxCRRiis5G5mbczsVTNbamZLzGx4ufNmZo+b2QozW2Bmgyurq7587+TDeeziQbx1/ckxjWN5Nc/Ii4hEQ7g9998Bk929P3AcsKTc+XOAvsHXeOCpiEVYS6nJSXz9+K4M6NqaH8Zw2b0YPYwkIo1ctcndzFoDpwJ/BnD3fHffUa7Y+cCLHjITaGNmnSMebS39ZPSRvPHDESX7024+LaL1/+JrR1d6zlF2F5H6F07P/XAgB3jezOaa2bNm1rxcma7AmlL7a4NjZZjZeDPLMrOsnJycWgddU6nJSRzfoy1vXX8yd517NF3bNI1Y3a//cARXnnR4pefVcxeRWAgnuacAg4Gn3P14YA9wa20ac/dn3D3T3TMzMjJqU0WdDOjamqtOPpz01GRm3T6KO8fVfVHtwT3aVnl+7fY8dobxxMxXW/fy+LTlxOqlMhFJLOEk97XAWnefFey/SijZl7YO6F5qv1twrME6rFU65x3XhcNaNWHazaex4J6z+PLBsQB8c0i3knLXnd6H7510OLPvPJM/XDaY75R6SWpIz4OJffypvXnpqqHcd/4xh7R1TQXTG5R35V8+5dEpX7A+V4uJiEjdVTu3jLtvNLM1ZtbP3ZcBo4DPyxWbCPzIzP4BnAjkuvuGyIcbWR1bpTPr9jPLHDswZcDsr7Zzat8Mbjm7f8m5sQM706ZZKi/OWM2g7m147dqD4/i3jw39FXBK3wwen7acLbsPTiz2+Yad1cayryA0s2RxsXruIlJ34U4cdj3wNzNLA74ErjSzawDc/WlgEjAWWAHsBa6MQqz1avrNIys8PqJPB/4xfliVC3I/851MLvjDJyX7O/bqRSYRqV9hJXd3nweUn6jm6VLnHbgugnE1aMN6t6/yfJfWtb9ha5pIUkQiIGHfUI2lTq1rv3yf7qeKSCQoudcTPQUjIvVJyT1KRvYr+6jnz19bQFb2thhFIyKNjZJ7lFyU2b3M/j+z1nLh0zPYunt/jCISkcZEyT1Kxg7szDs/PnTB7iH3T+WFT7K59z+LYxCViDQWSu5RdFTnihfs/sXExTz/cXaFc71v3qWevYjUnZJ7DA25f2rJ9rodeQD8IIy3WUVEqqPkHkP5hcWHHNu6Rz13Eak7Jfcom3HbGVXOQplb7u1VPTEpIpGg5B5lnVs35ZrTeld6/uevLajHaESksVByrwd9Orao9NzkxRvrMRIRaSyU3OvBiD4dmH7zaYzq37HC8zdpEW0RiTAl93rSO6MFf7x8SIXnXp9bdur74mLn6hey+PazsygoOvSmq4hIdZTc61FKchKrHhpbbbmZX25l6pJN/G/FFj7TlAUiUgtK7vXMzJhwwUAe+MaASst869lZJdvJmgNYRGpByT0GLhnagwuO71Z9QUK/DEREaiqs5G5m2Wa20MzmmVlWBedHmllucH6emd0d+VATS9O0ZD68ZWS15S764wx27y+MfkAiklBq0nM/3d0HuXv5FZkO+Cg4P8jd74tEcImuZ/vm9GjXrNpyp//6g+gHIyIJRcMyMfbhLSN5ppKnaA7I2bWf37+/op4iEpFEEG5yd+A9M5ttZuMrKTPczOab2TtmdkyE4kt4ZsZZx3Tik1vPqLLcI+8u
q6eIRCQRhJvcT3b3wcA5wHVmdmq583OAnu5+HPAE8GZFlZjZeDPLMrOsnJycWgediLqUm3/mWyf2OKTMJyu3aLk+EQlLWMnd3dcF3zcDbwBDy53f6e67g+1JQKqZdaignmfcPdPdMzMyMsqfbvQW3Xt2yfaD3xjIc98te3vjW3+axeG3TSqZHtjd+fCLHIqLyyb8zbv2sXzTrugHLCINVrXJ3cyam1nLA9vAWcCicmU6WfDMnpkNDerdGvlwE1uLJim8du0IHrnwWADO6H9YheVOmjCdaUs28fbCDVzx3Ke8NHN1mfMnT3if0b/9b9TjFZGGKyWMMocBbwS5OwX4u7tPNrNrANz9aeBC4FozKwTygEtc4we1MqRnW4b0bFuyf+6xnXlrwYZDyl31QhadWqUDoZWd/rdiC3/6Tqinn68pC0QavWqTu7t/CRxXwfGnS20/CTwZ2dAE4MlvDeatBW9XeG7jzn0l21M+31RfIYlIHNCjkHEge8K4WIcgInFGyT1OvP7DEdWWmTh/PTNWHrzVsTc/9Gbr+h15vDF3bdRiE5GGR8k9Tgzu0bbaHvwNL8/l0j/NLNn/3dTlAIyYMJ2fvDKffQVFUY1RRBoOJfc4899bTg+77J78snPS5OUruYs0FkrucaZH+2bMuWt0WGX/OvMr1gfPxANcVmoqYRFJbErucahd8zTu//oALji+a7VlR0yYXrL9+YadbN29P5qhiUgDoeQep749rCePXjyoxp874zcfltn/ZOUWJs5fH6mwRKSBUHKPc0vuG1Oj8rl5BWzauY+VObtZuDaXb/1pFje8PJdCvfgkklDCeUNVGrCmacksf+Ac/rd8C+t25HHnm4uq/cyJD0475FheQREtk5N4+dOvuO31hcy9azRtm6dFI2QRqQfquSeA1OQkTu/fkW+UGoO/7vQ+Napj4D3vkb1lDy98kg3Au4s3ArB0406e+mBlxGIVkfphsZoCJjMz07OyDlmxTyKo160VT1sQruwJ4+h/1zvsKyhm5YNjSU7Seq4isWZms6tYEa+Eeu4JLHvCOB7+v4G1/nyf2yexryA0Fl+gMXmRuKLknuCO7dYGgGZpydwwqm+NPltUap743LwC/vLxKt14FYkTSu4JrlXTVAC+dmwXbhp9ZK3rOfHBadzzn88Z+uA0CouKyyR+EWl4NObeCCxYu4N+nVrSJCWZnfsKOPae9wD461UnsmzTLpqnJXPr6wtrVGefjOZMu3lkFKIVkaqEO+auRyEbgQNDMwCt0lO5/owjOKZLK07u24GT+4ZWQ/z68V3pf9fksOtcmbOH3LwCWgd/GYhIw6Lk3gjdfFa/Q46lpybz6e2jaJmeyquz13DXvxdXW89/5q/n28N6RiNEEamjsMbczSzbzBaa2TwzO2QsxUIeN7MVZrbAzAZHPlSJto6t0mmalszlw3vx7+tOok9G80MW6S7tzjcXsW5HHv+Zv57Sw3u3vb6AeyZW/8tBRKInrDF3M8sGMt19SyXnxwLXA2OBE4HfufuJVdWpMff4sXnXPoY+cOhbreXNuWs0RcXOCQ9MBSpfQaqgqJj3Fm9i7MBOBGvzVmv7nnzMoE0zvTUrjVt9P+d+PvCih8wE2phZ5wjVLTHWsWV6WOUG/3JKSWIH2LmvgOJiZ+rnm+h169slM1L+/v0VXPf3Obwxdx179hdWVl0Zx/9yCoPum1Lz4EUaqXCTuwPvmdlsMxtfwfmuwJpS+2uDY2WY2XgzyzKzrJycnJpHKzHz2rUjePY7mTRPSw77M8fe8x43vjKPq18M/YX20fLQH34vzlgNwE3/nM+Q+ytP2FM+38SqLXvC/gUgIgeFe0P1ZHdfZ2YdgSlmttTd/1vTxtz9GeAZCA3L1PTzEjtDerYF4Pkrh3LRH2dwx9ijeGDSkmo/V3o64fzCYl6dvZZte/JLju0rKMbdDxmeKS52vv+ihu1Eaius5O7u64Lvm83sDWAoUDq5rwO6l9rvFhyTBDP08HZk
TxhHcbGHldxL+9lrCyo8PuerHfRo14yMlk24Z+JiTjy8Hdf+bU6FZbftyaedZqsUqVa1N1TNrDmQ5O67gu0pwH3uPrlUmXHAjzh4Q/Vxdx9aVb26oZoY6jo5WWmZPduStXp7teVm3T6K9s3TSEnWC9bS+IR7QzWc5N4beCPYTQH+7u4PmNk1AO7+tIX+pn4SGAPsBa509yozt5J7Yti6ez/3/udz9uYXMXXJpnpte9n9Y2iSEv49AJFEELHkHi1K7onnpn/OY1jv9lw4uBtvL9zA9S/PjXqbn9x6Bl3aNI16OyINhZK7xFxhUTFH3PFO1NvRXPPSmGhuGYm5lOQknvtuJv07teKfWWt4bOryqLTT5/ZJHNe9DatydnPSER14/NLjSQ3G4wuLijU2L42Seu5SL/YVFPHKZ2s4pksrlmzYyeXDe+Hu/OCl2Yw7tjN9O7bknomL+TR7W5nPZU8YR15+ERty82jdNJUh90+tpIWynvtuJklmfPf5z3j0ouO4YHC3aFyWSL3TsIzEteJip6C4+JAbprV9OufTO0bxxpx1DO7Zlr98nI0ZnNK3A98c0p0kDelIHNGwjMS1pCSjSdKhT8JcPqwnL81cXeP6Kpob560FG5i0cCPXnNaH3hnNWb8jj+N7tK1VvCINjXruElfyC4tZvD6Xb/zhk6jU//S3hzBmQKeo1C0SCVogWxJSWkoSx/doy+J7z2b+3WfRO6N5ROu/5q+zKS52/vDBCnbvL+S5/63i6LvDX8Rkb34hD05awr6CoojGJVJT6rlLXCsqdoqKnXcXb4z6c/VfPjiWzbv206l12Vky8/KLeHjyUn56dj+e+98qHp3yBT8f059rR/aJajzSOKnnLo1CcpKRlpLE147rwry7R5c5F+k5aPrfNZlhD01j0bpcsrfs4eHJS3F3XpyRzV8+yeaPH67kqQ9WAvDw5KURbVukpnRDVRJGm2ZpnHREez5esZXP7zubZmkprN66h9Me+QCAVQ+N5ZRfvc/a7Xm1qj+/qBiAc5/4X8mxaUs2lTxT/8T0FWXKf7JyC5k925GWksSufQU0SUkmLUX9KakfGpaRhLJrXwFfbNpdMkUxhBJw/86t6NqmKXvzC/nNe1/w07P6Mfer7Tw2dTnpacnMX7OD3LyCqMT0/HdP4Mq/fEZmz7a8eu2IqLQhjYeecxepoSemLec3U76Iaht/v/pEvvXsLD2VI7Wm5C5SQwVFxby7eCMbc/dx/9s1m6u+Nq4/4wguPqE7BUXOkg07GXNMJ71QJdVSchepg/2FReTmFfDEtBWM6NO+0sVDIu2yE3tw3/kDajQR2r6CIvYXFNO6WWoUI5OGQsldJMIiuTBJVY7t1pqJPzq5zLE9+wt5/uNVvL8shytG9GLt9r307tCcMQM6c8ZvPuDLnD1kTxhXYX1Fxc4PXsrigsHd6N62GQO7ta6w3Pw1O9i+N5+R/TpG/JokcpTcRSIsd28BTVKTSE9NZtzjH7F4/c4y56fdfBrd2zYjLSWJrOxtXPj0jDq3efIRHTisVTqvzVlbbdlPbj2DlCQjv6iYbm2bsWLzLvLyi5m7Zjt3/3txSbkz+nfkxjP7cmy3NmU+f+CXV2W/JKRhUHIXiSJ358Mvcji1bwY79xWQm1dAz/Zl35atr/nsKzLlJ6cy+rdVr2F/IIm/t3gjvTOac+ajofIPXTCQTq3SOaZrKzq2TK+qComBiCd3M0sGsoB17n5uuXPfBR7h4KLYT7r7s1XVp+QujcWKzbtp0SSFYQ+FJi9787qTaN00ldN//UFM4xpzTCeuGNGLS/80s9IyWgil4YnGrJA/BpYArSo5/4q7/6gG9Yk0Ckd0bAEcOtxx/qAuHHlYS95bvBGA+Wtz6zWuyYs3MjlouzJ/nbma7wzvSWiZZIknYfXczawb8ALwAHBTJT33zJokd/XcRcr668zV/Pq9ZezYG52Xqepi0b1n89CkJfTr1JLTjszg9TnruPHMvkr6MRDRYRkz
exV4CGgJ/LSS5P4QkAN8AfzE3ddUUM94YDxAjx49hqxeXfN5uUUS3bQlm7jqhYbf8WnfPI0J/3cso48+LNahNCoRmzjMzM4FNrv77CqK/Qfo5e7HAlMI9fIP4e7PuHumu2dmZGRU17RIozTqqEOT5fE92nDH2KP48xWZ/O3qE8OqJ9pj5Vv35PP9F7OYvnQTxcWxeTBDKldtz93MHgIuBwqBdEJj7q+7+7crKZ8MbHP3ih+mDWhYRqRyeflFmEF66qGrUcHBxxZXPTSWJ6av4IrhvfgsextXv5jFiD7t+fv3h7EhN4/hD00H4IrhPVm8fidXnnQ4j039guWbd0c85o9vPYOubZpGvF4pKyqPQprZSCoeluns7huC7W8AP3f3YVXVpeQuUntLN+5k3fa8Mr38lTm7GfWbD7nl7H5cd/oRlX62uNhZuC6XF2Zk8/qcdZjBj0f1BeCxqcsBuGPsUTwwqeZTMFw6tAc3n3UkHVo0AUK/pPbmF5KakkTLJikao4+AqCd3M7sPyHL3iUHv/jxCvfttwLXuXuWE1kruIpG3ZtteurZpGvYcNVnZ2+jUOp1ubZsBZV9kKi52PlqxhSue+7RGMRzVuRX/umY4b8xZy12lXp6697xjuGJELwAmL9rAh19s4aELBtaobtFLTCJSC+6OOxX+cjji9kkU1mFsvXXTVB7+v2Ppk9G85AWrLx8cW9LWjr35tGkW2QVWEpGSu4hE1Fdb9zLzy60M7Naa1k1TGTFhekTqveXsfmzdnc9zH6/ihe8N5bQj9bBFVbTMnohEVI/2zbjohO4c1bkVXdo0ZekvxwDwg9N616neR95dxnMfrwJg9urtdY5TQrTMnojUSnpqcslbtx8szWHZpl20bppapxWtvti4i5U5u+mT0SJSYTZa6rmLSJ29+5NTyZ4wjjvHHVXmePn96kxevJFRv/kQCC2eIrWnnruIRMz5g7ryxaZdFBQ53x7WgyM6tuSsoztx6iPv16ie8nPnf2d4T342pj8tmihlhUs3VEUk6p6cvpwmKclcPrwnq7bs4S8fZ3Pv+cfQ/67JNarng5+OZOSvP+De847h4hO6A/D+0s0M6tGGzq0rf4HK3Xn6wy85b1CXuH/RSk/LiEhc+MMHK3BukMIoAAAKV0lEQVQP3Vitiwe+MYBvDulOXkERTVKSyrzd+9XWvZz6yPsc06UVb99wSl1DjqloTPkrIhJxPxwZepu2rsn9jjcWcccbiwBo2yyV7XsL+OlZR3Ld6UdQWBwav9+zv7BuwcYR9dxFpEHILyxm574CMu+fGtV2sieMY/mmXfTJaBH2m7wNiZ5zF5G4kpaSRIcWTfj09lElc91EwxPTljP6t//lDx+siFobDYF67iLSIO3ZX8jPXl3A2ws3RKX+JIN3bzyVvoe1jEr90aIbqiKSMJZt3MXZj1W94HdtHde9DYe3b8a95w+gddPUqLQRSRqWEZGE0a9TSz68ZWRU6p6/ZgdvzlvPcfe+x5pte6PSRiyo5y4icSN7yx6Wb95NlzbpHNYqnZxd++mT0YK0lIP91EkLN/D+0s38a/ba2rVRbiFzgG178mnRJKVMO7GinruIJJxeHZoz+ujDOKZLazq0aMJRnVsdknDHDuzMI988jhvK3ZR98XtDw2rj8j/Potetb/PSjGz+8vEq3J3Bv5zCVS98FqnLqBfquYtIQtuxN5+iYqd9iyYsWpdLzu79XPl8+In6lfHDuPiZmUBohaq9+UX8+My+jHnsv+QVFPHhLadHK/QKRfwlpmBt1CxgXQXL7DUBXgSGAFuBi909u0YRi4hEQekFQAZ0DS3t/IfLBjP18028PnddtZ8/kNiBkqUH2zZPZenGXRGONLJqMizzY6CyRRWvAra7+xHAb4GH6xqYiEi0jB3YmUcvHlTrz99davnAI+98hy2790cirIgKK7mbWTdgHPBsJUXOB14Itl8FRplWwhWRBm7pL8eULDpSW/mFxbw4YzUQWhC8oQi3
5/4Y8DOgsgmWuwJrANy9EMgF2tc5OhGRKEpPTS5ZdGTpL8dw+9j+tarn8WnL6XXr2xx192Q+XrGlwjJLN+6kuA5r0NZUtcndzM4FNrv77Lo2ZmbjzSzLzLJycnLqWp2ISMSkpyYz/tQ+LLkv1JO/5ex+tarnsmdnHXJswdodjHnsI343bXmdYqyJap+WMbOHgMuBQiAdaAW87u7fLlXmXeAed59hZinARiDDq6hcT8uISDzYtHMf4x7/iL9/fxizV2/nttcXVvuZHu2a8VXwQtRt5/SnQ4sm3Pyv+QC8/sMRDO7RttbxRGX6ATMbCfy0gqdlrgMGuvs1ZnYJcIG7X1RVXUruIhKPJs5fzw0vz+Wa0/rw9Icra1VHRS9KhSvq87mb2X1AlrtPBP4MvGRmK4BtwCW1rVdEpCE7d2BnduzN56LM7owZ0Imv//7jWIdUIb3EJCJSB3v2F/LgpCX8bdZXYX/m0YuO44LB3WrVnqYfEBGpB82bpPDANwbWaKjln1lrohhRiJK7iEiE3DnuKAD6HdaSn43px4mHt6uwXF5BZU+VR47WUBURiZCrT+nN1af0Ltn//im9GTFhOjm7yr7BmlIPy/spuYuIRElqchLv/3Qk05duZvuefC4Y3JU/fbSKi0/oHvW2ldxFRKKoRZMUzjuuS8n+TaOPrJd2NeYuIpKAlNxFRBKQkruISAJSchcRSUBK7iIiCUjJXUQkASm5i4gkICV3EZEEFLNZIc0sB1hdy493ACpeyypx6ZobB11z41CXa+7p7hnVFYpZcq8LM8sKZ8rLRKJrbhx0zY1DfVyzhmVERBKQkruISAKK1+T+TKwDiAFdc+Oga24con7NcTnmLiIiVYvXnruIiFRByV1EJAHFXXI3szFmtszMVpjZrbGOpy7M7Dkz22xmi0oda2dmU8xsefC9bXDczOzx4LoXmNngUp+5Iii/3MyuiMW1hMPMupvZ+2b2uZktNrMfB8cT+ZrTzexTM5sfXPO9wfHDzWxWcG2vmFlacLxJsL8iON+rVF23BceXmdnZsbmi8JlZspnNNbO3gv2EvmYzyzazhWY2z8yygmOx+9l297j5ApKBlUBvIA2YDxwd67jqcD2nAoOBRaWO/Qq4Ndi+FXg42B4LvAMYMAyYFRxvB3wZfG8bbLeN9bVVcr2dgcHBdkvgC+DoBL9mA1oE26nArOBa/glcEhx/Grg22P4h8HSwfQnwSrB9dPDz3gQ4PPh/kBzr66vm2m8C/g68Fewn9DUD2UCHcsdi9rMd83+QGv7jDQfeLbV/G3BbrOOq4zX1KpfclwGdg+3OwLJg+4/ApeXLAZcCfyx1vEy5hvwF/BsY3ViuGWgGzAFOJPR2YkpwvOTnGngXGB5spwTlrPzPeulyDfEL6AZMA84A3gquIdGvuaLkHrOf7XgblukKrCm1vzY4lkgOc/cNwfZG4LBgu7Jrj8t/k+BP7+MJ9WQT+pqD4Yl5wGZgCqEe6A53LwyKlI6/5NqC87lAe+LsmoHHgJ8BxcF+exL/mh14z8xmm9n44FjMfra1QHYD5u5uZgn3rKqZtQBeA250951mVnIuEa/Z3YuAQWbWBngD6B/jkKLKzM4FNrv7bDMbGet46tHJ7r7OzDoCU8xsaemT9f2zHW8993VA91L73YJjiWSTmXUGCL5vDo5Xdu1x9W9iZqmEEvvf3P314HBCX/MB7r4DeJ/QkEQbMzvQuSodf8m1BedbA1uJr2s+CTjPzLKBfxAamvkdiX3NuPu64PtmQr/EhxLDn+14S+6fAX2Du+5phG6+TIxxTJE2EThwh/wKQuPSB45/J7jLPgzIDf7cexc4y8zaBnfizwqONTgW6qL/GVji7o+WOpXI15wR9Ngxs6aE7jEsIZTkLwyKlb/mA/8WFwLTPTT4OhG4JHiy5HCgL/Bp/VxFzbj7be7ezd17Efo/Ot3dLyOBr9nMmptZywPbhH4mFxHLn+1Y34SoxU2LsYSeslgJ3BHreOp4
LS8DG4ACQmNrVxEaa5wGLAemAu2Csgb8PrjuhUBmqXq+B6wIvq6M9XVVcb0nExqXXADMC77GJvg1HwvMDa55EXB3cLw3oUS1AvgX0CQ4nh7srwjO9y5V1x3Bv8Uy4JxYX1uY1z+Sg0/LJOw1B9c2P/hafCA3xfJnW9MPiIgkoHgblhERkTAouYuIJCAldxGRBKTkLiKSgJTcRUQSkJK7NGpmdqOZNYt1HCKRpkchpVEL3qLMdPctsY5FJJLUc5dGI3iL8O1gbvVFZvYLoAvwvpm9H5Q5y8xmmNkcM/tXMA/Ogbm6fxXM1/2pmR0Ry2sRqY6SuzQmY4D17n6cuw8gNHPheuB0dz/dzDoAdwJnuvtgIIvQnOQH5Lr7QODJ4LMiDZaSuzQmC4HRZvawmZ3i7rnlzg8jtEDEx8EUvVcAPUudf7nU9+FRj1akDjTlrzQa7v5FsJzZWOB+M5tWrogBU9z90sqqqGRbpMFRz10aDTPrAux1978CjxBa4nAXoSX/AGYCJx0YTw/G6I8sVcXFpb7PqJ+oRWpHPXdpTAYCj5hZMaGZOK8lNLwy2czWB+Pu3wVeNrMmwWfuJDQLKUBbM1sA7Ce0HJpIg6VHIUXCoEcmJd5oWEZEJAGp5y4ikoDUcxcRSUBK7iIiCUjJXUQkASm5i4gkICV3EZEE9P8PU/C01Tai7AAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAEWCAYAAABollyxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3XecVNX9//HXZ/rO9g7s0qRIkSJNVFTsisZu0NgSNRiNiYnxF/s3mug3mvLVJGrsRo2KvaFiwa6AgALSO+zCsr3P7k47vz/uXVyaLMju7Mx+no8Hj7lz750758xjec+Zc889V4wxKKWUSlyOWBdAKaVUx9KgV0qpBKdBr5RSCU6DXimlEpwGvVJKJTgNeqWUSnAa9ErtQEQ+FpHL98NxlorI5P1QJKV+EA16FTdEZIOINIlIg4iUish/RCQl1uXaHWPMcGPMxwAicpuI/DfGRVLdlAa9ijc/MsakAGOAccAte/NiEXF1SKmU6sI06FVcMsZsBt4BDhKRdBF5TERKRGSziNwhIk4AEfmpiHwhIveISCVwW5t194lIrYisEJFjd/deInKpiCwXkWoReVdE+trrDxORChHpbT8fZe8zxH6+QUSOE5GTgJuAqfavkUUicq6ILNjhfa4Vkdc75ANT3ZoGvYpLdrhOAb4B/gOEgYHAwcAJQNs+9kOAdUA+cGebdWuBHOAPwCsikrWL9zkdK6TPAnKBz4DnAIwxXwIPAU+KSBLwX+BWY8yKtscwxswE/hd43hiTYowZBbwB9BeRoW12vQh4ah8+DqW+lwa9ijeviUgN8DnwCfAoVuD/xhjTaIwpA+4Bzmvzmi3GmH8ZY8LGmCZ7XRlwrzEmZIx5HlgJnLKL9/sF8GdjzHJjTBgrsEe3tuqB24B04CtgM3B/eyphjGkBngcuBBCR4UA/YEZ7Xq/U3tCgV/HmDGNMhjGmrzHmKqxWuhsoEZEa+0vgISCvzWuKdnGczWb7Gf02Ar12sV9f4B9tjl0FCFAAYIwJYf2iOAj4u9m7WQKfBH4iIoLVmn/B/gJQar/SoFfxrghoAXLsL4AMY0yaMWZ4m312Fb4FdsC26gNs2c3xr2hz7AxjTJLdbYOIFGB1/TwB/F1EvLsp505lMMbMAYLAEcBPgKe/v6pK7RsNehXXjDElwHtYIZsmIg4RGSAiR+3hpXnAr0XELSLnAkOBt3ex34PAjXbXCvaJ33PtZcFqzT8GXAaUAH/azfuVAv1EZMf/c08B9wEhY8zneyizUvtEg14lgosBD7AMqAZeAnru4TVzgUFABdYJ2nOMMZU77mSMeRW4G5guInXAEuBke/Ovsb4wbrW7bH4G/ExEjtjF+71oP1aKyNdt1j+N1e2jY+xVhxG98YjqbkTkp8DlxphJXaAsSVgnhscYY1bHujwqMWmLXqnYuhKYpyGvOpJeJahUjIjIBqwRPGfEuCgqwWnXjVJKJTjtulFKqQTXJbpucnJyTL9+/WJdDKWUiisLFiyoMMbk7mm/LhH0/fr1Y/78+bEuhlJKxRUR2die/bTrRimlEpwGvVJKJTgNeqWUSnBdoo9eKaX2RSgUori4mObm5lgXpUP5fD4KCwtxu9379HoNeqVU3CouLiY1NZV+/fqx/WSkicMYQ2VlJcXFxfTv33+fjqFdN0qpuNXc3Ex2dnbChjyAiJCdnf2DfrVo0Cul4loih3yrH1rH+A76jbNh1h8hGol1SZRSqsuK76DfPB8++zsEG2NdEqVUN1RTU8MDDzyw16+bMmUKNTU1HVCiXYvvoPckW4+hQGzLoZTqlnYX9OFw+Htf9/bbb5ORkdFRxdpJfI+6cdtBry16pVQM3HDDDaxdu5bRo0fjdrvx+XxkZmayYsUKVq1axRlnnEFRURHNzc1cc801TJs2Dfhu2peGhgZOPvlkJk2axJdffklBQQGvv/46SUlJ+7Wc8R30Hg16pZTl9jeX
smxL3X495rBeafzhR8N3u/2uu+5iyZIlLFy4kI8//phTTjmFJUuWbBsG+fjjj5OVlUVTUxPjx4/n7LPPJjs7e7tjrF69mueee45HHnmEH//4x7z88stceOGF+7UecR70futRu26UUl3AhAkTthvr/s9//pNXX30VgKKiIlavXr1T0Pfv35/Ro0cDMHbsWDZs2LDfyxXXQT+nqJmJQKS5AWesC6OUiqnva3l3luTk5G3LH3/8MR988AGzZ8/G7/czefLkXY6F93q925adTidNTU37vVxxfTK2pMmK9+bA/v25ppRS7ZGamkp9ff0ut9XW1pKZmYnf72fFihXMmTOnk0v3nbhu0fuSUwFobqgjeQ/7KqXU/padnc3hhx/OQQcdRFJSEvn5+du2nXTSSTz44IMMHTqUAw88kIkTJ8asnHEd9EnJaQA0BxpiXBKlVHf17LPP7nK91+vlnXfe2eW21n74nJwclixZsm39ddddt9/LB3HedZOcYgV9sHnXP52UUkrFedCnpFpBH27SoFdKqd2J66BP83sJGC/hZh1Hr5RSuxPXQZ+e5CaAFxPUPnqllNqduA76ZI/TCvoWbdErpdTuxHXQiwgt4tMrY5VS6nvEddADBB1JOMIa9Eqp2Lvtttv429/+Futi7CTugz7kTMKpQa+UUrsV90EfcSbhjuz/uSGUUqo97rzzTgYPHsykSZNYuXIlAGvXruWkk05i7NixHHHEEaxYsYLa2lr69u1LNBoFoLGxkd69exMKhTq8jHF9ZSxAxJ2MO7j7oF9b3kAkahicn9qJpVJKdbp3boCt3+7fY/YYASfftdvNCxYsYPr06SxcuJBwOMyYMWMYO3Ys06ZN48EHH2TQoEHMnTuXq666ig8//JDRo0fzySefcPTRRzNjxgxOPPFE3G73/i3zLsR90BuXH19090H/pxnLCAQjvHDFoZ1YKqVUd/DZZ59x5pln4vdbU6afdtppNDc38+WXX3Luuedu26+lpQWAqVOn8vzzz3P00Uczffp0rrrqqk4pZ9wHvXiS8ZmW3W6vawrRHIp2YomUUjHxPS3vzhSNRsnIyGDhwoU7bTvttNO46aabqKqqYsGCBRxzzDGdUqa476MXbzJJNNMc3PU9GptCUVrCkU4ulVKqOzjyyCN57bXXaGpqor6+njfffBO/30///v158cUXATDGsGjRIgBSUlIYP34811xzDaeeeipOZ+fcSSPug97pTcYphrqGXc930xyKEIxoi14ptf+NGTOGqVOnMmrUKE4++WTGjx8PwDPPPMNjjz3GqFGjGD58OK+//vq210ydOpX//ve/TJ06tdPKGfddN64k6yRrQ309eVmZO21vCkaIGtPZxVJKdRM333wzN998807rZ86cucv9zznnHEwnZ1Lct+jdSSkANNbX7HJ7UyhCS1hb9Eqp7ivug95rt+gbG3Z9O0Er6PfQR//+H6B4wf4umlJKdQlx33WTlWF111RWV++0LRI1BMNRHGKdEBGRnQ8QbIQv7rWWC8d2ZFGVUh1gt/+3E8gP7eqJ+xZ9eno6AJVVOwd9c8hqySebAJENX+76AI0V1mOL3mBcqXjj8/morKzs9D7vzmSMobKyEp/Pt8/HiPsWvaT1AsBZuWKnbU120F/sfA/n06/AjZvBvcOHFWgNer1LlVLxprCwkOLiYsrLy2NdlA7l8/koLCzc59e3O+hFxAnMBzYbY04Vkf7AdCAbWABcZIwJiogXeAoYC1QCU40xG/a5hHuSPYAizwDG1L4Pi0aBicDonwDWiBuA3lKGRMNWmO8Y9I2V1qMGvVJxx+12079//1gXo8vbm66ba4DlbZ7fDdxjjBkIVAOX2esvA6rt9ffY+3WoVflTGBpZhXntF/D+/4D9M66166aXtIb5LrpntEWvlEpw7Qp6ESkETgEetZ8LcAzwkr3Lk8AZ9vLp9nPs7cdKB58pqR14GlEjIA5oLIeK1QA4139MLtX0kCprx+banV/c2kff
rH30SqnE1N4W/b3A74HWAenZQI0xpnXegWKgwF4uAIoA7O219v7bEZFpIjJfROb/0P61vF4HcEnoev4n9Y/Wio2fQ3Mt/d+9mCtdb34X9N/botegV0olpj0GvYicCpQZY/brQHNjzMPGmHHGmHG5ubk/6Fh9s/18Fh3J06V9KScTNnwOm+YiJsoYxyrSxJ7dclfdM9pHr5RKcO05GXs4cJqITAF8QBrwDyBDRFx2q70Q2GzvvxnoDRSLiAtIxzop22F6ZSSR5HYSjkb5MjKUH63/HEd6bwBGyvrvdtxV90zbFr0xkODjcZVS3c8eW/TGmBuNMYXGmH7AecCHxpgLgI+Ac+zdLgFaZ+15w36Ovf1D08GDXJ0OYfq0ifzt3FHMiozB0ViKmfcoAA5p89a76p5ptLuNomEIN3dkMZVSKiZ+yAVT1wPXisgarD74x+z1jwHZ9vprgRt+WBHbZ1TvDMb3y2JGdCJbvf2QYANbfAO332mXXTcV220vrWvm7H9/SWnd7kO/JRyhrF6/FJRS8WGvgt4Y87Ex5lR7eZ0xZoIxZqAx5lxjrLt/GGOa7ecD7e3rOqLgu9Iz3Yff6+HWhrMBeMF56rZtUYd716NuApXgt88Vt9TzbXEtCzZW882mXU+SBvDoZ+s5/v8+JaTTHyul4kDcT4HQlogwMC+F9yNjOaLlHp4MHEadSaLcpBH0ZOzcdRNqhmADZNoXXDTXUtNk3ai3vKHNXatKFsOdPaF6IwCrSuupbQqxpqyhM6qllFI/SEIFPcDgfGva4iKTT3VTmJWmN5tNDkFnCmUVFUSjbfrsW0/EZh1gPbbUUxMIAlDetuvm079CKAAbrflyttRYo3iWbtEhmUqpri/hgv7QAdn0SPOR4bfurH5z6DJuCl1OdcTH8vXFfLKqzZj91v75LLtF31JPTcBq0ZfVt8BLl8Ebv4LSpdZ2nzWB2pYa60tg6ZZddAUppVQXk3BBf+bBhcy+8RgKMpIA2OTqyzLTj1qTRJoEGPbWmbBourVz444t+jpqmqwWfXLFIljyEix8DqrWWttDAcKRKFvrWoO+DmqLIRzstPoppdTeSrigB6uvPjvFC0B6ktWyr4l4GSBbyK9fAqvfs3ZsKLUes+3ROW1a9CdWPAlOL0RD245bUV1NWX0LkajB73GydksF5v5DYMETEG6BhrLOqaBSSu2FhAx6gJxkDwBpPivoK8I+0iRgbSyzpzRutIO5TYu+timEjxYmhObBxCshvc+2Y5ZX1Wzrnz9yUC6+lkok2ADVG2D2ffDgpA6vl1JK7a2EDfrsFCvo/R4nXpeDqvB30xObytUQCVstcE8K9Y5UjNMLLfVUB4IUiNWlE80bBqPOIyTWsVqaGthsB/0ZBxeQLfbJ2EAlVK6zfiFEwiilVFeSsEGflWx13fjcVtA3kLRtm0SCUGUFs0nJY8Rt71GPH5rrqAmE6C3WCdtab0846vf8PPsJAEJNDdtOxB4xKIfxuVaom8aK766wDTV2VhWVUqpdEjboW1v0SR4nHpeTeuMHoMVYXTmUL4eGMiJJ1oRqlSEPwUAttYEQw/zWaJoyRx5RcfFVmYuA8RJqaWRzTYAMv5tkr4tj+1gfX3Nt6XdBH9SgV0p1LQkb9DmtQW+36Ouwgv6L6HBrh7IV0FBGiy8HgAaSKN5aSn1LmGH+GoLGyZZIOkXVAQLBCAG8RFsCjN3wKHe4n4BQM2OyrBO1ofqK70bwaNArpbqYuL9n7O5k2103SW4nXreDBmN13WxJGkRZZCt55cuhoZRA7iEAbDVZjK3+lizq6O+qosRkU94QJhi15sdpwUM0GODkwEx8tMAz5+DrMQIAX7AaIvZ3ZlCvllVKdS0J26Jv7brxeZx4XU6rDx7I6jOcb4KFRDZ9Bc01NLiyAPio1zSSTYD/dT9Gvimj2OSytqKBBRurrZmL3X4IBSg1GdYbbPgMyqw7K3pMC4TtOe+DgU6tp1JK7UniBn2bFr3H5WBptC/L
k8cz4JApLIgOwllvTZ9fiRXc/Ycfwr/CZ3KScx5Ztcsx6YU8N3cTz87dxJSDemLcSURbAqQQoD7Jvht70dyd31i7bpRSXUzCBn2Sx8mxQ/IY1zcTr8tBNWk8O+heBg8cTHHqqG37rWtOBuCIwTk8EzmWoHHiiAYZNHg4dc1hGlrC/OrYgRiXHx/NpBKgMe9g68WhAOEde7+060Yp1cUkbNADPPbT8Zw8oidel1XNtCQXIsKV559F2B4bv7TO6rsflJeKJy2XD6JjAcjvM5jzJ/Thool9GdIjDTx+0mnEIxHcPYaB2/qCqErqu/2baoteKdXFJHTQt9oW9PZVsiP65tOSNxKAb6o8pPlcOB3CiIIMnokca70obyh/PmsEfzrjIACcHj+5Ys1Rn5aRAzmDAGhI2+HmJiHto1dKdS3dJOidAKTZ894ASL9JtBg3K+q9ZNrTJRw5OIfVyeOIXrsKeo3e7hhOr58srBE4bn865B4IQDBr8PZvpl03SqkuJmGHV7bV2qJP9X1XXd/R13Hm53kEcZNhfwFcNLEv50/og8O58/efOyn1u/vPelO3Bb0zZxBh4wB3Eq5Is3bdKKW6nO7Rondv33UD4PClUp4yBIAMv9WiFxHcuwh5AF9ScpsnaZA3DAB/bh+qSaXZmw2eFA16pVSX0y2C3uNsPRnr3m59Xpo10Vmm373Ta3bk9ae0eZIKg06A854lffAkKk0aDc5M8CRr0CulupxuEfRet91H79u+pyo/zRpr39qi/z5Oj7/NAdPA4YQhp5Dic/MKxzIvcwp4/Br0Sqkup3sEvWvXLfp8u0Wf0Y4WPZ42XTfetO02vZd6BjO9J2iLXinVJXWLoM/we/C6HNudjIU2QZ/UjqB3fzfNMd7U7TYN75XOrOWlNBqfBr1SqsvpFkF//oTevPmrSduGWbbKS7W6blqHV34vt9114/KBa/v9bzttOJl+D9+WhzE6vFIp1cV0i6D3e1wMzk/daX3PdKuVntmOPvptQb9Dtw1AbqqXSw/vT0XQTbRFW/RKqa6lWwT97hw6IJu/nD2SwwZk73nn1q4b785fGAC9s5IIGC/RFm3RK6W6lm4d9E6H8OPxvXHtZuz8dlpb9L6dW/QAhZl+GtE+eqVU19Otg36v7KlFn+kngBdnWOe6UUp1LRr07dU6vHIXffQA6X43YZcfhwlDONiJBVNKqe+nQd9e21r0uw56AF+S3drXkTdKqS5Eg7699tBHD+BLtrdpP71SqgvRoG+vPfTRA6SkWbcl1LH0SqmuRIO+vdxJMPkmGH7mbndJS0sHoLamsrNKpZRSe6RBvzcmXw/5w3e7ObVgKABbVn3TWSVSSqk90qDfj8aMHkMDfjYu+RJjTKyLo5RSgAb9fuVxu2jMGk7PwAo+XlUe6+IopRSgQb/fZQ+ewFApYslG7adXSnUNewx6EfGJyFciskhElorI7fb6/iIyV0TWiMjzIuKx13vt52vs7f06tgpdi6vXwXglRFLd6lgXRSmlgPa16FuAY4wxo4DRwEkiMhG4G7jHGDMQqAYus/e/DKi2199j79d99BoNQEbN0hgXRCmlLHsMemNpHRjutv8Z4BjgJXv9k8AZ9vLp9nPs7ceKiOy3End1WQMI4yQ9sCnWJVFKKaCdffQi4hSRhUAZ8D6wFqgxxoTtXYqBAnu5ACgCsLfXAjvNAywi00RkvojMLy9PoBOXDgeVjmySW8piXRKllALaGfTGmIgxZjRQCEwAhvzQNzbGPGyMGWeMGZebm/tDD9elVDlzSQ+VxroYSikF7OWoG2NMDfARcCiQISKtN2EtBDbby5uB3gD29nSgWw1BqfPkkRmuiHUxlFIKaN+om1wRybCXk4DjgeVYgX+OvdslwOv28hv2c+ztH5pudvVQgzef7GgFdK9qK6W6KNeed6En8KSIOLG+GF4wxswQkWXAdBG5A/gGeMze/zHgaRFZA1QB53VAubu0gK8HXkLQWAEpidUtpZSKP3sM
emPMYuDgXaxfh9Vfv+P6ZuDc/VK6OBX097AW6jZr0CulYk6vjO0A4ZRe1mNNcYxLopRSGvQdIppqBX2wSsfSK6ViT4O+AzjT8ggaJ+HqzXveWSmlOpgGfQdI8XnYarIwtdp1o5SKPQ36DpDidbGVLBz1Jd+t1KGWSqkY0aDvAMleF+UmHUeTPbVDQzncPwEWPPn9L1RKqQ6gQd8BUn0uykwm7ib76tgP/wQVq+Ddm6BW++2VUp1Lg74DtLboPaE6KFkMXz9l3VQ8GoZP7op18ZRS3Ux7roxVeynF66KcDOvJircAA8fcal0pW7EmpmVTSnU/2qLvAMkeJ+Um3XpSNAfECRl9ICkTmqpiWzilVLejQd8BXE4H9c4sACJF89lssnllUSn4syDQrSbyVEp1ARr0HaTRkwOAM9TAhkgu176wiHUBHzRV61BLpVSn0qDvICFfNlGsOyhWuHqQlexhabXTOiHbUhfj0imluhMN+g6SnOSl3mH104fT+zKqMJ0VtW5ro3bfKKU6kQZ9BxnXN4uSSBoAntz+jOqdwYo6e5BToDqGJVNKdTca9B1kyogelEWtFn1WwWBG9c6gKppqbdSRN0qpTqRB30HG9Mmk3p0NQMEBQxldmEE1KdbGgAa9UqrzaNB3EIdD8BUcRAk59C7oTWayh+SMPGuj9tErpTqRBn0HOuKi2/BcMx+n0/qYs7NzieDQrhulVKfSoO9AHo+b7MzMbc9z0pKoIwVqimD2/RAJxbB0SqnuQoO+E+Wmeqk2KZhvX7Bmslz1bqyLpJTqBjToO1Feqo8qk4KYqLVizfuxLZBSqlvQoO9EeXaLvpVZ/YFOh6CU6nAa9J3I6rqxxtK/HDkCqSuG8hUxLpVSKtFp0HeivFQvK00hmyWf/wudY61cPiO2hVJKJTwN+k6Ul+bjscgpHNn0NzaTy5bMCdbdp6KRWBdNKZXANOg7UbLHSZLbSQQnAHOyT4faTbBm1nc7la+EcEuMSqiUSkQa9J1IRMhL8257/qljAiTnwTdPWStKFsH9h8CCJ2NUQqVUItKg72R5qVbQi8DmujAMOQXWfgThIMz6E2Bg66LYFlIplVA06DtZXqoPgJEF6Wyta4ZBJ0CwAT77mzWuXhxW941SSu0nGvSdrCAzCb/HyYT+WZTWthDtdwQ4PfDJ3Zj0QhZln0ykbIWOr1dK7Tca9J3syqMG8OIvDqUw008wEqUq7IG+hwOwYuQNvFiShzNYD/UlMS6pUipRuGJdgO4mM9lDZrKH4uomALbWNpN0yK8JpA3nyeqRrI+WWjuWLYe0XjEsqVIqUWjQx0jPdKuvvqS2mT9/6Wf22sPwuLbgNwXWDuUrYeCxMSyhUipRaNdNjPSwg/6+D1fzxZpKslO8NIei5PUooMqkEi1bHuMSKqUShbboYyQv1ceVkwfw0CdrGZSXwktXHsa89VVUB4J8/fpAjlrzIQ5jrHGYSin1A2jQx9D1Jw3hnLGFpHpdpCe5OW5YPguLavhvdALH1T8EW76GgrGxLqZSKs7tsetGRHqLyEciskxElorINfb6LBF5X0RW24+Z9noRkX+KyBoRWSwiYzq6EvFsQG4KeWm+Ns+TeS8ylog4YelrMSyZUipRtKePPgz8zhgzDJgI/FJEhgE3ALOMMYOAWfZzgJOBQfa/acC/93upE1iqz01SWjZLPAdjlr8Z6+IopRLAHoPeGFNijPnaXq4HlgMFwOlA66QsTwJn2MunA08ZyxwgQ0R67veSJ7BLD+/PGw1DkOr1UF8a6+IopeLcXo26EZF+wMHAXCDfGNN6Vc9WIN9eLgCK2rys2F6347Gmich8EZlfXl6+l8VObFccNYBewycB0LTxqxiXRikV79od9CKSArwM/MYYU9d2mzHGAHt1zb4x5mFjzDhjzLjc3Ny9eWm3MGDkoYSNg6qVs6G5TqdEUErts3YFvYi4sUL+GWPMK/bq0tYuGfuxzF6/Gejd5uWF9jq1F4b17cEK04fk
tW8R/ssg6ub8J9ZFUkrFqfaMuhHgMWC5Meb/2mx6A7jEXr4EeL3N+ovt0TcTgdo2XTyqnfJSfaxyDSYjsAFXtJl1n78Y6yIppeJUe1r0hwMXAceIyEL73xTgLuB4EVkNHGc/B3gbWAesAR4Brtr/xe4earJGAbDVZHJAw9csWF+2h1copdTO9njBlDHmc2B3l2fuNBmL3V//yx9YLgU0HXgm07aEOXZQOlM33sZ7789k7LSLY10spVSc0bluurAjhxayLP1IJp1wDlGEw7b8h+iX98OmObEumlIqjugUCF3YiMJ0Pr/+GADKcg7hqIo58N4Ca+NVcyFvSAxLp5SKF9qijxPBn7zM4OYneWPCM9aKzQtiWyClVNzQoI8ThVkp5Gak8U5FHrj9sHVxrIuklIoTGvRx5NAB2Xy+tppgzjAoWQzzH4evHoFI+LudKtbotAlKqe1o0MeRKycPIBw1fFLXE7PlG3jrOnj7OnjmHGsHY+Cp02DmDd9/IKVUt6JBH0cG5KZw8ylD+aCmBxJuAhMlNOoiWPcRVK6FitVQt5lwybexLqpSqgvRoI8z50/oQ0PmUAAWJx/K8fPHAxBZ8TZNqz8GQKrWQjgYqyIqpboYDfo443QIp51wAi+Ej+LWutMZfOBwlkf7UL/oDYq+fs/ahyiBratiXFKlVFehQR+HThhRiPPMB/jHNRfxz/MP5lPHONLK5tOz4nPWiTWf3Kql82NcSqVUV6FBH4dEhLPHFtIvJxmf20nZ4AuZHRlKMs14j/odUYSytYtiXUylVBehQZ8Azpo8jnsL/sbyy9dQMPlnVLl7YMpW0ByKxLpoSqkuQIM+AQzvlc6LvziM4YU5ADjyhjAwup6X5m8C4NMVW1lZUvd9h1BKJTAN+gSUOeZMBjhKyJh1HcE3rmXMc6PZ+Nw1sS6WUipGNOgTkIy5mE39z+PU8AfIN0/RhJshdbPb9+IPboenz+zYAiqlOpUGfSISoeDCB/h50j0c1PQIj4ZPoQ8lrFy7joc+XE7k3ZuhesPOr4uEYMF/YO1H0FLf2aVWSnUQDfoE5XQ6OXrycbTgoTxzNADTX36RNbMexzn7Pnj/f6wdazbBynes6RPWfwJNVYCx5tJRSiUEnY8+gZ01poAks6pzAAAalklEQVQVW+uYOnocLY/fTM+6RRztXEgEB85lr8Oz58GqmYCByTdaUyi4kiDcBFu+gX6Hx7oKSqn9QFv0CczndvLH0w9ieN98VjsHcrHzPQY5NnN76CIa8BNe8yHzCy/hzcih8PGfYclLrOxxCjXuPCvolVIJQVv03cTi/DNpLnmNkUecTl/X2fxhySRmFzWxZU0OB/c6g2Wl0xk24mD+sq4fN7ds4oTNX1utgGAjNNVAag9wOGNdDaXUPtCg7yaO/8lvqQn8Ek9+KpcB5shBPPLZOoqqmvjDj4Zx8eNJPLWshsZghMXO/pxU/RUUzYMXLoL6Euh3BFzyJsju7hOvlOqqtOumm8hN9TIoP3XbcxFh2pED+NMZB+FyOjh3XCGNwQhup/BG9DCaXanw+IkQqCQ86gLY8Bls/DKGNVBK7SsNegXAScN7kup1cdTgXDw5/bgv83qiCPWTbubiknOoJxnz5b9g/afb39FKKdXliTEm1mVg3LhxZv58nW0x1paX1JGd7OEv767kpQXFpBDAeFJpDEa4xfU0l7vesXb80T9h2OkQaoK0nrEttFLdmIgsMMaM29N+2qJX2wztmUZemo8jBllz5lxw1EF43U6OH5bPE65zebHHtZDeB1a8Ba/8XK+gVSpO6MlYtZPTRvVi0sAcslO8/Pa4wbidDm55zcMfFiZzyrgG/AufgGgIEGtUjic51kVWSn0PbdGrnYgI2SlewBqL73QIl03qjzFw28redsgDGChfGbuCKqXaRYNetcvAvFTuPW80r1X1pdqkssz0szaULW//QbrA+SCluiMNetVuJw7vwZxbTqbi/Hc4v+Umwg6vNeRy+gWw9FVr
J2OgsfK7F0VCEG6BkkXw9yHw2d/3LvAjIajeuH8rolQ3o330aq9kJXvIGjKCEQMDrC4uYMjCZxAMrJgBm+ZCKACLnoNpn0DuEHjmHKvV70uHQAXM+iMEquDEO3c+uDHQVA1JmTDjN5AzGNZ9AmtnwRWfQf6wzq+wUglAg17tk1tPHcayBwoZ6liHGXAskjMI5v4bgCiCef82nP0Ph3UfE3Sn42kohanPwLqPYPZ91hdCMAAFY2DMxeD0wEuXwur34bxnrOmSWzk98P6tcOHLsamsUnFOg17tkwN7pFI5aCys/ZR5BRfxny29uWSAh83FG1nb4OX/rXkB1rzLF46xXNt8Je+cl8WjGwu4dNJR5Gz5BuY/Acm5sHi61Wr3+GHZa9bB3/i19XjU9eBJAXHAezfDoukw6rzYVVqpOKVBr/bZ+LOu4dZ/+XjuAw/haClvMxGYSJorzIBk4YChY7n08zxa8HDuTCdry9cSCEb4wyUzKK+uJerLpMfyJ2DmDSAOto69jtDyt+hduxwKxvJ65iVUNAS57JACazrl16+GaBhGna8TrCm1FzTo1T5zJ2cy+fSf8fST87nk0L5srWumujHE5CG5XDvTRY/FPvKzHKR4XSwrqcPjcjB93ia+3VzLgo3VeJwO3vvtxfQ7MwuyB3LTB1EKa8v5o3s55sAp/PXdlRRXN5GV7ObMqf+1+vtf/yUsfh4ufAWc7lh/BErFBQ169YMcOzSfWb87iv7ZyTgcgjGG6kCIWcvLCEei/OKoATSFIvxxxjL+df7BXPz4VyzZXMt1Jwzmvo/W8I9Zq7ln6lQ21zTx8coPSTZHMNFfwpDCMyiuXkmq18X1L38LZ49gyJSXKFzzLKkf3ghvXweHXwOZ/XVGTaX2QOe6UZ0iGjU4HMJbi0vom+3noIJ0/vzOch7+dB2Zfg9Oh1DR0MLvTxzC3TNXMKZPBl9vquHNqydx+5tLmb+xetux/uB+ip85ZwJgkrKQQ38Jk64Fh44WVt1Le+e60aBXMVMTCHLLa0tI8bpYV9HIkB6p3Paj4Zz3yBy+Wl/FgNxkZv1uMi3hCDMWleBxOSita6ayoYXS1QvwlH7DVT1X0afiUygYCwdfBON+tvMbRUIw63YYdgYU7vH/hFJxY78FvYg8DpwKlBljDrLXZQHPA/2ADcCPjTHVIiLAP4ApQAD4qTHm6z0VQoNetVVUFWDKPz/jool9+f1JQ3a5jzGGnz81n09Xl/PRUesoWP0MlC2Di1+HAyZbM2s63OB0wcwbYc4DkHMgXPmltU6pBLA/g/5IoAF4qk3Q/wWoMsbcJSI3AJnGmOtFZArwK6ygPwT4hzHmkD0VQoNe7ag2EMLvdeJ27r47pry+hSn//Iz65hAHZnt4qPYKcnNycGYPsMbjJ+fAwGPh66egcAIUf0XJwPPoOfI4GHGO9u2ruLffpik2xnwKVO2w+nTgSXv5SeCMNuufMpY5QIaI6ITlaq+l+93fG/Jg3TVrxq8mccyQPHB5ubvlHJzly2HTHBh7iTUq5+unYMI0IpfMYKEMpeea6fDK5daQzjY3UKnfuoZvHv4FzbVlHV01pTrdvv6GzTfGlNjLW4F8e7kAKGqzX7G9roQdiMg0YBpAnz599rEYqrvLT/PxwAVjAbj0CQ8XburN/VdezLyiAK7CXzA5twF6jmLu2gp+2nQDyTRxZ857TJn7IKz5wLoXrj8LM/dpDg6WsemVFvr87IkY10qp/esHD1MwVt/PXp/RNcY8bIwZZ4wZl5ub+0OLoRS/OX4wc1r6M+X+eVz+1HwufW4Fr27N5ptN1bwwrwi3x8cvTzmEqyrO5rMx9xL2ZdKy5HWin99LoCXMjMhE+mx8xfpF0FagCmqL230LxYaWMBsrGzughkrtm31t0ZeKSE9jTIndNdP6e3cz0LvNfoX2OqU63MjCDJ742XiueuZrThreg+KaAL99ftG27WcdXMClh/fng+WlXPaVk0j0WiJRgxDF43IwItfF
2Jqr6Tn9AjjhT+DPoTp9KL5/jyOJZkgrhGNvtaZhqC+1pnDYNBtWvWNN1+C1br4+47l/k7bhbXpe/yKepJRYfRxKbbOvQf8GcAlwl/34epv1V4vIdKyTsbVtuniU6nBHDMpl/i3H4XE6qGsO88WaCtxOB5urA5wwvAcOh/CXs0fxP28sYVjPNI4ekkem34ND4P1lpfxk5vW85boD/2tXAlCZMpb+poWnM67gopT58OoVULcFProTRpwLG76A2k2w4m0Y+1PMiHM5bsPfyaGa6ucvx3PWPXpfXRVz7Rl18xwwGcgBSoE/AK8BLwB9gI1Ywyur7OGV9wEnYQ2v/JkxZo/DaXTUjeoKFhXVcPr9X5BOA32llIf9D9AjUsI7HM7VwauZfd1E0p6YjK9+I8bhRlrvtHXMrYQWv4K7Yilhlx9XOMAbkUM5zTnb3n4LjPgx1G+FPodY0zHriB+1H+gFU0rtpXAkyrSnFzC+Xxbl9S1sXvwhtzkfo+T4f3PWi5VkJ3s4ILCYO92PUXHMXzlsw79pScrlDu/veH5eESOjy3jIcy/Lon241nc7J2aVckfO+7D0VSLiBhPBTL4R1/zH4OAL4JhbrUnaVr8PSRmQ2gMayuDbF2HjbDj/WcjsF+uPRXVhGvRK7SfRqOHQu2ZR2RDk9tOH858vNhA1hpeuOJTzHp7D2opGzh3Xm3F9M7n5xa84IDeF8QN78eKCYhbdMhnH61fx/rJS+kSKGObYCN50aKmFQ68Gtx8+/cv2b+hwW1MzDzgaDrkC0vtAzsDYVF51ae0Ner1EUKk9cDiEBy4YA8DYvllkJ3v5xX8XMPGuD4kaw5OXTuDwgTkARKKG7BQPxsCTszdy4RPfMKbv7/l301rGpDcypeUtJp93O00f/JkRs+8jipPmAaewpOdZDE1pJDUjF3pPgIXPWjdbWTXTujvX+c9D30O/K9SyN2DzfJh8E0Hx4AnXg8sHLm/7KlWxBubcb00bUTBmf39kqovRFr1Se8kYwwMfr6W8voXjh+VvC/kd95k+r4i7Z66gJhDigNxkHr5oLFMfmkNlYxAXYV5IuouB0fUc1/I3ysgkP83LgxeOpUe6j9fmb+DyyPNEM/rhmX0vUr0eCsbBKX+Hljp4+iyIhgjmjeThrYO5wjMTd1IqTPg5ZPSFmo3WXbocLsjoA40VMOBowi1NyOATcc59wLrbF8AZD8Lo83df4ZJF8Mav4OS/WucYVJehXTdKdQFNwQizVpRyQE4Kw3qlsaWmibveWcHxw/L50fAcVm3cxKwi4YDcZO54axmVDUFyUrxsqgowrm8m6yoa6ZvUxN8HLaXPqidwNZZaB84aAEdeR+3bt5EeLGWxcxgjeqYgxV9te+81/tF4k/z09jTS7EjGs3kOYSO4xOAgyr/Mj7kody0Z9Wvg0ncgexC4PN8VfuOXsOQV685fjeXWxHGXz9r+RHKo2foy0fmDYkKDXqk4U1bXzE8enUtRVYALDunL41+s58D8VCobg1Q0tJBBPX/v/SXBpGxWZB1Lv779uO31pfR2VbO0IZme6X4CteWMzQ4xadRgbv+wAqdDuPTwfjz55UaSHEFG9EjmN6U3kiX1nBK6m15U8F7SjTgjLTQ5U/ENn4KpL+OlhhGcUfUobjFIegEM/RF8fg8c8TsYfQGk9oQP74D5j4MvDSbfuOuZQ1WH0qBXKg41toSpDgQpzPSzvKSO/jnJNAUjLCyq4Ys1FTz6+Xo8LuuC9mA4CsCzlx/CLa8vIRiOcubBBTz3VREVDS0ckJNMXXOIioYgxw3N544zDiIz2c2p//iUspo6nr1yMr+ZvpC0+tUMCq7gSMcijvWtoJkk0kOl1Jhk/jX4PxxzyBgi4RCTvrwMx6YvME4PktkfKlbByKlQtQ62fAO/WWx9Aew4dDQaher1kN57+18M6gfToFcqAW2paSLT78HlFNaWN1ATCDHxgGyaQxFcDsHldLBsSx2/fX4hN58ylGSvk5VbGzhv
fG8cDiuAy+qaqQoEGdIjjeUldZx+/xf0zfKTlexh7voq3IT51wFzKE8fyR8XpxOKWBmRm+olrXkLNzufZrJrMc0/eoh5SYfTI7KVwS8chfQ9jEjZchyH/Ro5/NfQXAtfPQLzHoGGUqvr56LXrF8Aar/QoFdKtcuKrXXkpngJRw3PzytiYF4KJwzLx+V0UN0Y5NvNtQTDUZ79ahOZfg/vL9uKy4SoDgqt8fGf1IeYHPqEcpNOrtRat3gMVFonjgedAIXj4ZO7YcAxcMGLsa1wAtGgV0p1iDnrKnnok7WMLMxgfL8siqoD/PvNLzgqOoe33MdzQuhjruu/geysTKITr8bRcyRfrKnkoPWPk/HlnXDhy7D2I+saglUzrSGhZz8K4rTeoGyZdaFY7oExrWc80KBXSnWaBRureWtxCVdOHsD5j8xhXXkDfbOT2VzdxJCeqSwuriXb3cxs769xRwJgQIgSzByIp7EUgvXbH9Dls7p52l47oHaiQa+Uion65hB/fmcFGysb6Z3pZ+76Ks48uID5G6s5ZN19XOl6gyuDv+GT6EgyUlM4vTBAyvq3OXrMQRxUkGbNEjrzeqhcA/4cOPAk6DUG+h+18xXCDWWw+Wtrn+oNkFZg3XCmm9CgV0p1KU3BCGfe9xn1ZRs4/rDxnDWmgPMfnkNjMELfbD8bKwOcODyf8yf0oWrrRpKXv8gobwn5JR8iwQbwpsHJf7HmAiqeZ10ctvFLa6roCdNg3mPWhV95w2D2AzD+UmvOoOr11nmCk/4c649gv9OgV0p1OcXVAT5ZVc754/vgcAiLi2toaAkztm8mD368jkc+W0dDi3WDF6/LQUs4Sm6ScOdRfk5Y8AuoLyGcnE+Fu4AeNV9bB03taa0XN04TRpxucHqt7qD03pDWC4rmwoWvWPcQbqtkMXxxL5z4v9akcnFGg14pFXcCwTBfb6zB73VyUK90PltdzkOfrmPehiquPihC3tZP+HP5YQRx8Vz6AwwvzGZGr1/Tc8mD3Ln1EF7w3kGyR3hz0qucMsiPI3cwEg3jeGgShINw2XvgTrJmDS36Cl7/JTRVWXP+nH5frKu/1zTolVIJoTkU4bfPL+STVeVkp3i4eGI/mkIR/u/9VaT6XNQ3W78AJg3MoWHtHMQhfBMZQF6ql5qmEMFwlJ/2reSWyhtoiTrxmwbEWBebkd4Heo+Hpa/CxW9AuAW2LraGgwL0PqRLX+Sls1cqpRKCz+3k3xeO3W5dMBzl5a+LqQmEeOGKQxGBsX0yuexJobIxyD2H92PGohL65STjEHjiC2Gx+X/c5H6WOdGhVJh0tpBLU/LRHJrkYVrSp7iePHXnNz/wFJj6X2u55Bvr/sEf3Ab+bDjjAWtK6ZS8Ln8jGW3RK6XiUmldM8ZAj3TftnXGGGQXofvZ6nI+XVXOz488gFnLy3CKsL6ykU9WlrOytB5/tJHL0uexuQE+jozk70c6OTJpA3xyFww51Rrd0zphXGpPaKqGcLP1PL0PHHAUHDgF+kyEUBOkF1jbIiGIRsDt26lM+4N23SilVDuU1Tfz0oJivt5YQ99sPyu31jN3fSWnjezFj0rvZ0LtTBxOFx/0uJxqZw7rk0dxbI9mRjbPY8nWJga3fEtW2Rykpe67g46+wJrxc/1nVuv/6q9g67eQP3zbTeT3Bw16pZTaB7VNIe56ZwWvfbOZFJ+L2qYgoXCEJI8bn9tJKBKlvjm83fmBA7I8XN1nI4WRYtwNJRxcMp2QPw/noONwLHoW+h0BGz6D3CFw0l1Qsdqa/tnpgUN+YV0HsA806JVS6geIRg0Oh1DfHKK8voV+2ck4HEIwHOWOt5bx9aZq7jhjBEVVAR7/Yj2LimqI2nGaTS11JFOQncZLafeSU/Ixm5MOJC+yFXew1topf4Q1/cPh18Cw0/apjBr0SinViVrCERpbIqT6XGytbWbBxmru/2gN4fLV3Oh/g5sbf0xYXNw4OsRZhx7ICsdgHvhkLVccOYBRvTP2
6T111I1SSnUir8uJ12VNzNY7y0/vLD8nj+jBLa9mcMXXPfnfs0awcFMNv59fxG1LqwgEvyDV6+LE4T32OejbS1v0SinVgYwx1DWFSfdbc/DMWLyFL9ZUMrRnKqePLiA9ad/n5tEWvVJKdQEisi3kAU4d2YtTR/bq1DI4OvXdlFJKdToNeqWUSnAa9EopleA06JVSKsFp0CulVILToFdKqQSnQa+UUglOg14ppRJcl7gyVkTKgY172C0HqOiE4nQ13bXe0H3r3l3rDd237vta777GmNw97dQlgr49RGR+ey71TTTdtd7QfeveXesN3bfuHV1v7bpRSqkEp0GvlFIJLp6C/uFYFyBGumu9ofvWvbvWG7pv3Tu03nHTR6+UUmrfxFOLXiml1D7QoFdKqQQXF0EvIieJyEoRWSMiN8S6PD+UiDwuImUisqTNuiwReV9EVtuPmfZ6EZF/2nVfLCJj2rzmEnv/1SJySSzqsjdEpLeIfCQiy0RkqYhcY69P6LqLiE9EvhKRRXa9b7fX9xeRuXb9nhcRj73eaz9fY2/v1+ZYN9rrV4rIibGp0d4REaeIfCMiM+zn3aXeG0TkWxFZKCLz7XWx+Vs3xnTpf4ATWAscAHiARcCwWJfrB9bpSGAMsKTNur8AN9jLNwB328tTgHcAASYCc+31WcA6+zHTXs6Mdd32UO+ewBh7ORVYBQxL9Lrb5U+xl93AXLs+LwDn2esfBK60l68CHrSXzwOet5eH2X//XqC//f/CGev6taP+1wLPAjPs592l3huAnB3WxeRvPeYfRjs+rEOBd9s8vxG4Mdbl2g/16rdD0K8EetrLPYGV9vJDwPk77gecDzzUZv12+8XDP+B14PjuVHfAD3wNHIJ1JaTLXr/t7xx4FzjUXnbZ+8mOf/tt9+uq/4BCYBZwDDDDrkfC19su566CPiZ/6/HQdVMAFLV5XmyvSzT5xpgSe3krkG8v767+cf252D/LD8Zq3SZ83e3ui4VAGfA+Vqu0xhgTtndpW4dt9bO31wLZxGG9gXuB3wNR+3k23aPeAAZ4T0QWiMg0e11M/tb15uBdkDHGiEjCjnsVkRTgZeA3xpg6Edm2LVHrboyJAKNFJAN4FRgS4yJ1OBE5FSgzxiwQkcmxLk8MTDLGbBaRPOB9EVnRdmNn/q3HQ4t+M9C7zfNCe12iKRWRngD2Y5m9fnf1j8vPRUTcWCH/jDHmFXt1t6g7gDGmBvgIq8siQ0RaG1tt67Ctfvb2dKCS+Kv34cBpIrIBmI7VffMPEr/eABhjNtuPZVhf7hOI0d96PAT9PGCQfabeg3WS5o0Yl6kjvAG0nlG/BKv/unX9xfZZ+YlArf3T713gBBHJtM/cn2Cv67LEaro/Biw3xvxfm00JXXcRybVb8ohIEtZ5ieVYgX+OvduO9W79PM4BPjRWB+0bwHn26JT+wCDgq86pxd4zxtxojCk0xvTD+n/7oTHmAhK83gAikiwiqa3LWH+jS4jV33qsT1i086TGFKwRGmuBm2Ndnv1Qn+eAEiCE1ed2GVZf5CxgNfABkGXvK8D9dt2/Bca1Oc6lwBr7389iXa921HsSVr/lYmCh/W9KotcdGAl8Y9d7CfA/9voDsAJrDfAi4LXX++zna+ztB7Q51s3257ESODnWdduLz2Ay3426Sfh623VcZP9b2ppbsfpb1ykQlFIqwcVD141SSqkfQINeKaUSnAa9UkolOA16pZRKcBr0SimV4DTolbKJyG9ExB/rcii1v+nwSqVs9hWc44wxFbEui1L7k7boVbdkX7n4lj1H/BIR+QPQC/hIRD6y9zlBRGaLyNci8qI9R0/rPON/seca/0pEBsayLkrtiQa96q5OArYYY0YZYw7CmmVxC3C0MeZoEckBbgGOM8aMAeZjzaveqtYYMwK4z36tUl2WBr3qrr4FjheRu0XkCGNM7Q7bJ2Ld8OILe3rhS4C+bbY/1+bx0A4vrVI/gE5TrLolY8wq+3ZtU4A7RGTW
DrsI8L4x5vzdHWI3y0p1OdqiV92SiPQCAsaY/wJ/xbq1Yz3WLQ4B5gCHt/a/2336g9scYmqbx9mdU2ql9o226FV3NQL4q4hEsWYRvRKrC2amiGyx++l/CjwnIl77NbdgzaIKkCkii4EWrNu9KdVl6fBKpfaSDsNU8Ua7bpRSKsFpi14ppRKctuiVUirBadArpVSC06BXSqkEp0GvlFIJToNeKaUS3P8HUdRCNdmmIAsAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# plot loss\n",
    "plt.figure()\n",
    "plt.plot(losses)\n",
    "plt.title(\"Total loss\")\n",
    "plt.xlabel(\"step\")\n",
    "plt.savefig(loss_fig)\n",
    "plt.show()\n",
    "\n",
    "# plot perplexity, skipping the first few (noisy) evaluation points\n",
    "PERP_SKIP = 5\n",
    "# the perplexity lists can hold one more entry than `steps` (presumably a final\n",
    "# evaluation appended after the loop -- TODO confirm); truncate by slicing rather\n",
    "# than popping so the recorded history stays intact and train/dev are trimmed alike\n",
    "n_steps = len(steps)\n",
    "plt.figure()\n",
    "plt.plot(steps[PERP_SKIP:], perps[PERP_SKIP:n_steps], label=\"train\")\n",
    "if dev_source_data is not None:\n",
    "    plt.plot(steps[PERP_SKIP:], dev_perps[PERP_SKIP:n_steps], label=\"dev\")\n",
    "plt.title(\"Perplexity\")\n",
    "plt.xlabel(\"step\")\n",
    "plt.legend()\n",
    "plt.savefig(perp_fig)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Check trained ECM model: internal memory and external choices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Norms of final-state internal memory:\n",
      " [4.9488392e-08 2.2703186e-07 1.4774101e-07 1.0272628e-06 9.4477564e-08\n",
      " 3.0802039e-07 2.0407163e-07 1.7523462e-07 2.1692833e-07 1.7939273e-07\n",
      " 9.5731451e-08 1.2859005e-07 6.2726293e-08 3.0073764e-07 1.8602023e-07\n",
      " 6.4615580e-07 3.0753051e-07 4.0815900e-07 1.1175257e-07 5.0409494e-07\n",
      " 1.7497564e-07 3.8119015e-07 7.7888693e-08 3.6049221e-07 2.1327084e-07\n",
      " 3.3154348e-07 7.8305739e-08 2.3314762e-07 2.4495864e-07 2.6986237e-07\n",
      " 2.4019579e-07 9.2927522e-08 1.4504516e-07 2.4237568e-06 9.9325121e-07\n",
      " 1.7717542e-07 1.2133442e-07 5.7879942e-07 9.8482133e-07 1.7662971e-07\n",
      " 5.3112819e-07 4.9701134e-07 1.8288803e-07 8.4730154e-08 1.9798597e-07\n",
      " 1.0218571e-07 2.3952785e-07 4.2140226e-07 4.7942365e-07 1.5186853e-07\n",
      " 3.3082961e-07 6.9179924e-07 3.2807799e-07 4.5293552e-07 7.4886493e-08\n",
      " 2.1518625e-07 1.0240418e-07 1.9807288e-07 7.3894519e-08 2.8741323e-07\n",
      " 2.3738041e-07 1.0459391e-07 1.8776549e-07 1.3394013e-07]\n"
     ]
    }
   ],
   "source": [
    "(cell, train_log_probs, alphas, int_M_emo) = train_outs\n",
    "\n",
    "# fetch the final-state internal memory once and take per-row L2 norms in numpy,\n",
    "# so re-running this cell does not add a new tf.norm op to the graph each time\n",
    "M_norms = np.linalg.norm(sess.run(int_M_emo, feed_dict), axis=1)\n",
    "print(\"Norms of final-state internal memory:\\n\", M_norms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA60AAAEICAYAAAC06xKrAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAHSxJREFUeJzt3XmYZmdZJ+Df01t2kgAhhCQQkAABFAJNRGUEGZYIKoyjEFzB0eCKjCggDmPEAZRBgUsFjBAYZcmAiKzKrgGBkATZJCETYiAJgQDZ1+6ufuaP7+vuouylKl39nVPd931ddfW3nHPe5zv91qn61XvOe6q7AwAAAGO0augCAAAAYEeEVgAAAEZLaAUAAGC0hFYAAABGS2gFAABgtIRWAAAARktoBWCfVlVfqKrT5j2/pKp+e4A61ldVV9Vxu7md46bbWb+b23l9Vb17d7YBAMtBaAVgVKZhqadfG6vq4qp6aVUdNKMSHpLklYtZsKqeWlU37OF6Frb5XVX12qq6tKpuraqvVNXfVtX3L3NTv5nkZ5Z5mwCwZGuGLgAAtuODSX42ydok/ynJa5IclORXtrdwVa3t7o3L0XB3f3M5trMnTEdPP5Tk/Ez2xfmZ7JfHJ/mzJA9erra6+9rl2hYA7A4jrQCM0a3d/fXuvrS735TkjUmemCRV9YjpKOzjqupTVbUhyWOn7/1oVZ1XVbdU1b9X1Qurat2WjVbVnarqHVV183SE8hcWNrzw9OCqOrSqXlVVV0y3e35VPbmqHpHkdUkOmjcyfNp0nXVV9cdVdVlV3VRV51TVYxe0c3JVXTDd5keT3GtnO6SqKsnrk1yc5Ae6+93d/eXu/lx3vzjJf16wyt2q6gPT9r9YVY9esL0frKqzp+1/o6petmBffcfpwTXxrKr6f9MR3suq6sXz3j+6qs6sqqunX++pquPnvX/sdN9fNa3pgqo6ZWefGQASI60ArAw3ZzLqOt8fJ3lWkouSXD8NhW/M5LTWs5LcNcmrk+yXZEsIfX2SuyV5VJKbkrwsyXE7anQaFN+b5PAkT0tyYZJ7J9k/yceTPDPJi5J813SVLacKv2762k8luSzJ45K8q6oe0t2frapjk/x9kr9K8hdJvifJn+5iHzwwyf2S/HR3zy18s7uvWfDSC5P8TpJfTfI/kpxZVXfr7huq6ugk/5Dkb5I8dVrra5JszmSfbs+LMhnd/a1M9u8RSU6c7qcDk3xkuk8enmRDJvv8g1V1QnfflMkp1/sn+aEk12WyHwFgl4RWAEatqk7KJPx9aMFbp3X3++ct93tJ/nd3v2760per6jlJ3lBVv5Pk+CQ/nORh3f0v03V+PpORyx15VJLvS3K/7j5/+trW5avq2iTd3V+f99p3JXlKkuO6+6vTl/+8qh6V5OmZhMhfSfLVJM/o7k5yQVXdK8kf7qSWLaOW5+9kmfle1t3vmtb0vCQ/l0nw/di0hq8l+dXu3pzk/Kp6bpK/rKrnT0PmVlV1cJL/nuSZ3X3G9OWLknxi+viUJJXkadPPk6p6epIrk/xIkrdk8seCt3X3Z6fr/PsiPwcA+zihFYAxOnk6wdGaTEZY35HkNxYsc+6C5w9OctI0qG6xKskBSe6c5IRMRhI/teXN7v5KVX1tJ3WcmOSKeYF1MR6USYD74mSgdqv9knx4+viEJJ/cEvCmPpGdq128v9Dn5j3e8hnvtKD9zfOW+ViSdUnuuWDdJLlvJvUv/MPBFg9OcvdMRrznv35gto1CvyLJq6vq5Ol23t7d5y360wCwzxJaARijs5KcmmRjkq/tYJKlGxc8X5XkD5K8dTvLzp9cqbfz/nJaNW3jIZnUP9/Nu7HdC6f/npDkXxex/Na2u7unYXIxc1nclv2zKslnMhlxXeiqaQ2vrar3ZXKq9KOSfLyqXtzdp92G9gDYh5iICYAxuqm7L+ruryxhVuBPJ7nPdL2FX5uSXJDJ
z72TtqxQVXdNcpedbPNfkxxVVSfs4P0NSVZvZ51Kcuft1HH5dJnzk3xvfeew5EN38fk+k+SLSX6nqha2mao6bBfrz3d+kodW1fzfAx42/Txf3sHyt+Y/Tva0xaczGaH91nY+81VbFuruy7r79O5+UpL/mckfJgBgp4RWAPYWL0jyU1X1gqq6f1Xdp6p+oqpekiTd/aUk/5jJdZvfV1UPzGRipp2Nfn4oydlJ3lZVj62qu1fVo6vqidP3L0my//S1O1bVgd19YSYTQr1+2v49qmp9Vf12Vf34dL1XZzIB1Mur6t5V9RNJfnlnH256KvHTMjnd9mNV9SM1uWfrd1fVszO5TdBivTKTsP7Kqjqhqh6f5I+S/PnC61mnbV+fyem9L66qp03bPamqttyC6I1JvpHkHVX18Ol++sGq+pMtMwhX1SumMybfY7rvT84khAPATgmtAOwVuvt9mdyv9IcyuW71U0mem8mER1s8NZMJgD6c5F1J3pRJ8NzRNjdnMnnTvyR5QyYjjq/I5NrPdPfHMwmgb87kFORnT1d9WiYzCL8kkxHedyf5wSRfma731SQ/nklw+2wmkxw9dxGf8VOZXD96wbTd86fbPinJr+9q/XnbuXz6uU7MZAT3jOlneN5OVvvdTGZsfv603bclOWa6vZumn+/iTE7PviDJ/8lk1uWrp+uvyuResl9M8oFMQu7PL7ZmAPZd9Z1zQOwdppM8vCKTU7Ze091/NHBJsENVdUmS65PMJdnU3euHrQi2qaozMpn99cruvv/0tdsn+b+ZjBRekuRJ3X31jrYBs7KD/npakl/Ktuuan9fd7x2mQpiY3vbqr5Mcmcl15Kd39yscXxmjnfTX0zKj4+teF1qn1/lcmOTRmdwb75wkT+lupyAxStPQur67vzV0LbBQVf1gJvce/et5IeAlSa7q7j+a3ibl8O5+zs62A7Owg/56WpIbuvulQ9YG81XVUUmO6u5PV9UhSc5L8sRMzgZxfGVUdtJfn5QZHV/3xtODT0pyUXdf3N0bkpyZ5AkD1wSwInX3WZnO/jrPEzI59TPTf58YGIEd9FcYne6+ors/PX18fSan3B8dx1dGaCf9dWb2xtB6dJJL5z2/LDPeqbBEneT9VXVeVZlJk5XgyO6+Yvr465mcLgRj9utV9bmqOqOqDh+6GJivqo7L5Prys+P4ysgt6K/JjI6ve2NohZXmYd39oEwmRfm16eltsCJMZ7Tdu64zYW/zqkxmXH5gkiuS/Mmw5cA2VXVwJpOaPbO7r5v/nuMrY7Od/jqz4+veGFovT3LsvOfHTF+DUdpy38buvjLJ2zPvHpIwUt+YXt+y5TqXKweuB3aou7/R3XPTmaD/Ko6xjERVrc0kALyxu/9u+rLjK6O0vf46y+Pr3hhaz0ly/PQeceuSnJLknQPXBNtVVQdNL2hPVR2U5DFJvjBsVbBL78y2W5X8fJJ3DFgL7NSWADD1X+IYywhUVSV5bZLzu/tP573l+Mro7Ki/zvL4utfNHpwkVfW4JC/P5JY3Z3T3CwcuCbarqu6RyehqkqxJ8ib9lTGpqjcneUSSO2ZyX83fT/L3Sd6S5K6Z3Hf0Sd1t8hsGt4P++ohMTl3rTG4h8vR51wzCIKrqYUk+muTzSTZPX35eJtcJOr4yKjvpr0/JjI6ve2VoBQAAYO+wN54eDAAAwF5CaAUAAGC0hFYAAABGS2gFAABgtPbq0FpVpw5dAyyGvspKor+ykuivrCT6KyvFrPvqXh1ak/jGZ6XQV1lJ9FdWEv2VlUR/ZaUQWgEAACAZ8X1a160+oA9Yc7vd2saGuZuzbvUBu1fIqhHl+pH8X80duHboEpIkc+uGrmCbXr1768/deGNWH3TQ7texZhx9JDV0AdusG8kt2TcdOI6dsvbQDbu9jY3X3py1h+7msTXJhhvG8U1cB8wNXUKSpGok379J5jbs5kFtufTuf9/M3XBDVh988G5to8bRRZIka28cRz/pVeM4ptXcOPZH
kqzauHm3t7Fh001Zt+bA3drG5jXj+d11d38/Wi5j6a+9ZugKtllz5Y27tf7G3Jq12W+3tnFLbsyGvnVR/zkj2nXf6YA1t8v33+Wnhy4jfeD+Q5ewzaZx/NS87nuOGLqEJMl1x43kSJhkw6Hj+KF5653G0Ueyehz7I0nucebu/xKxHL75wN07sC+Xox731aFL2OqSjx87dAlJkv2/+5qhS0iSrF41jr6aJFd/7dChS0iS1IZx/KK57urx/Lw58pxNQ5eQJNl40DiC0X7XjuTnXpL9r7hh6BKSJBvuuHuhdzltPGQc3zsbRtJfb7nDOOpIkiP/7BNDl5CzN39w0cuOZ88BAADAAkIrAAAAoyW0AgAAMFpCKwAAAKMltAIAADBaQisAAACjJbQCAAAwWkIrAAAAoyW0AgAAMFpCKwAAAKMltAIAADBaQisAAACjNbPQWlUnV9WXquqiqnrurNoFAABg5ZpJaK2q1Un+IskPJ7lvkqdU1X1n0TYAAAAr16xGWk9KclF3X9zdG5KcmeQJM2obAACAFWpWofXoJJfOe37Z9LXvUFWnVtW5VXXuhrmbZ1QaAAAAYzWqiZi6+/TuXt/d69etPmDocgAAABjYrELr5UmOnff8mOlrAAAAsEOzCq3nJDm+qu5eVeuSnJLknTNqGwAAgBVqzSwa6e5NVfXrSd6XZHWSM7r732bRNgAAACvXTEJrknT3e5O8d1btAQAAsPKNaiImAAAAmE9oBQAAYLSEVgAAAEZLaAUAAGC0hFYAAABGS2gFAABgtIRWAAAARktoBQAAYLSEVgAAAEZLaAUAAGC01gxdwI702jXZdNThQ5eRzWtXD13CVjccu9/QJYzKqg1DV7DNIQ/69tAlJEnusGbT0CUkSa74+vDfu1t85XFrhy4hSVIbe+gSkiQXXnLnoUvY5phxfBPvd9Y4+usBj/n60CVsdf3txvF/U5ccMHQJSZJbjx3H/kiSr96lhi4hSbLu6+Oo467vH8fPvSTZvN84fq2+6U7j+LmXJIe98/NDl5AkOeDGG4cuIUly6NAFrGBGWgEAABgtoRUAAIDREloBAAAYLaEVAACA0RJaAQAAGC2hFQAAgNESWgEAABgtoRUAAIDREloBAAAYLaEVAACA0RJaAQAAGC2hFQAAgNESWgEAABgtoRUAAIDRmkloraozqurKqvrCLNoDAABg7zCrkdbXJzl5Rm0BAACwl5hJaO3us5JcNYu2AAAA2HuM6prWqjq1qs6tqnM3brxx6HIAAAAY2KhCa3ef3t3ru3v92rUHDV0OAAAAAxtVaAUAAID5hFYAAABGa1a3vHlzkk8kuXdVXVZV/20W7QIAALCyrZlFI939lFm0AwAAwN7F6cEAAACMltAKAADAaAmtAAAAjJbQCgAAwGgJrQAAAIyW0AoAAMBoCa0AAACMltAKAADAaAmtAAAAjJbQCgAAwGgJrQAAAIzWmqEL2JGNd+587Tmbhi4jB7/lwKFL2OqmI8bxN4a7nPH5oUtIkhx+1J2GLmGrH/uVTw5dQpLkL173hKFLSJIcsnHoCra5/vi5oUtIkhx+YQ1dQpLk9u8Zx/5IkprbPHQJSZJeO4468qnbDV3BVve85pahS0iSbDhy9dAlJElWbRhJH0my+uZxHGCvPf7goUtIktx49H5Dl7DVIRcP/3trkux33XiO8zc//L5Dl5Ak2XDIOI4l+10zjj6SJOved+7QJSzJOFIQAAAAbIfQCgAAwGgJrQAAAIyW0AoAAMBoCa0AAACMltAKAADAaAmtAAAAjJbQCgAAwGgJrQAAAIyW0AoAAMBoCa0AAACM1qJDa1U9papOmD6+d1WdVVUfqar77LnyAAAA2JctZaT1fyW5avr4pUk+leSfk7xyuYsCAACAJFmzhGWP6O5vVNX+SR6W5CeSbEzyrV2tWFXHJvnrJEcm6SSnd/crbkO9AAAA7EOWElq/WVX3TPLdSc7p7lur6sAktYh1NyV5Vnd/uqoO
SXJeVX2gu794G2oGAABgH7GU0PqHSc5LMpfkydPXHpXks7tasbuvSHLF9PH1VXV+kqOTCK0AAADs0KJDa3e/vqreMn180/TlTyY5ZSkNVtVxSU5McvZS1gMAAGDfs9Rb3hyQ5L9W1bOnz9dkCcG3qg5O8rYkz+zu67bz/qlVdW5VnTt33U3/cQMAAADsU5Zyy5uHJ/lSkp9O8vzpy8cnedUi11+bSWB9Y3f/3faW6e7Tu3t9d69ffbsDF1saAAAAe6mljLS+PMmTu/vkTCZWSian+J60qxWrqpK8Nsn53f2nS64SAACAfdJSQutx3f2h6eOe/rshizs9+AeS/GySR1bVZ6Zfj1tC2wAAAOyDljJ78Ber6rHd/b55rz0qyed3tWJ3fyyLuzUOAAAAbLWU0PqsJO+uqvckOaCq/jLJjyZ5wh6pDAAAgH3eok8P7u5PJnlAkn9LckaSf09yUnefs4dqAwAAYB+3lJHWdPflSV6yh2oBAACA77DT0FpVf5Ntky7tUHf/3LJVBAAAAFO7Gmm9aCZVAAAAwHbsNLR29x/MqhAAAABYaEnXtFbVI5M8JcldknwtyZnz7t0KAAAAy2rRswdX1bOSnJnkqiTvSfLtJG+avg4AAADLbikjrb+V5JHd/YUtL0wnavpAkj9Z7sIAAABg0SOtUwsnZro4i5hdGAAAAG6LpYTW05K8tqqOr6oDqupeSU5P8vtVtWrL1x6pEgAAgH1SdS9uoLSqNs972klqO8+7u1cvR2EnPmBd//M/HLkcm9ot/3zLYUOXsNXjD7xl6BKSJO+88cChS0iSPPMTpwxdwlb3fulNQ5eQJKlLvzF0CUmSzTfcOHQJW/Xc3NAlJElqVe16oRlYda97DF3CVr12WX5c7LaNhx8wdAlJkpuOXDd0CVvtf/WmoUtIkqy+eRzfv2uuvnnoErYay/fNLUeN43eBWw8bx/5IksP+/nNDl5Ak2XzTOH4ngV05uz+U6/qqRf2CtJRrWu9+G+sBAACA22TRobW7v7InCwEAAICFFh1aq+rQJM9IcmKSg+e/192PWea6AAAAYEmnB781yeokb08ynos7AAAA2GstJbQ+NMkdu3vDnioGAAAA5lvKLWo+luQ+e6oQAAAAWGgpI61PTfLeqjo7yXfcV6O7X7CcRQEAAECytND6wiTHJrkkye3mvb64G70CAADAEi0ltJ6S5F7dfcWeKgYAAADmW8o1rRcn2binCgEAAICFljLS+jdJ3llVf5b/eE3rh5e1KgAAAMjSQuuvTf990YLXO8k9lqccAAAA2GbRobW7774nCwEAAICFlnJNKwAAAMzUokdaq+p2SU5L8vAkd0xSW97r7rvuYt39k5yVZL9pm3/b3b9/G+oFAABgH7KUkdZXJnlQkhckuX2S30jy1SQvW8S6tyZ5ZHc/IMkDk5xcVQ9dYq0AAADsY5YyEdNjkpzQ3d+uqrnufkdVnZvkXdlFcO3uTnLD9Ona6VffloIBAADYdyxlpHVVkmunj2+oqkOTXJHknotZuapWV9VnklyZ5APdffZ2ljm1qs6tqnO//e3NSygNAACAvdFSQutnM7meNUk+lsnpwq9KcuFiVu7uue5+YJJjkpxUVfffzjKnd/f67l5/hzuYIwoAAGBft5Rk+EtJLpk+fkaSm5McmuTnltJgd1+T5CNJTl7KegAAAOx7dhlaq+rBVXX/7r64u79cVUdkcg3rSZmcLvzVRWzjiKo6bPr4gCSPTnLB7pUOAADA3m4xI60vT3Lnec9fk+ReSf4yyf2SvGQR2zgqyUeq6nNJzsnkmtZ3L7FWAAAA9jGLmT34hCQfTZLpaOkPJ7l/d19YVe9M8vEkv7qzDXT355KcuJu1AgAAsI9ZzEjrmiQbpo8fmuTr3X1hknT3pUkO20O1AQAAsI9bTGj9tyQ/OX18SpIPbnmjqo7OttvgAAAAwLJazOnBz0nyrqp6dZK5JA+b996Tk/zLnigMAAAAdhlau/tjVXXX
TCZfurC7r5/39nuSnLmnigMAAGDftpiR1kyD6nnbef1Ly14RAAAATC3mmlYAAAAYhNAKAADAaAmtAAAAjJbQCgAAwGgJrQAAAIyW0AoAAMBoVXcPXcN27XfcMX3n5z9j6DKSuRq6gq3u/ZufHbqEJElv3DR0CUmSWr166BK2qv33G7qEJEntv//QJSRJNh9356FL2GbT5qErSJLU3NzQJSRJaiT7I0l6zTj+bjp3yDi+b9Zced3QJWyzahz/N5sPHMexddVNtw5dwlabDx7HPll95bVDl5Ak2XTpZUOXANxGZ/eHcl1ftaiwNY6fSgAAALAdQisAAACjJbQCAAAwWkIrAAAAoyW0AgAAMFpCKwAAAKMltAIAADBaQisAAACjJbQCAAAwWkIrAAAAoyW0AgAAMFpCKwAAAKMltAIAADBaMw2tVbW6qv61qt49y3YBAABYmWY90vqbSc6fcZsAAACsUDMLrVV1TJLHJ3nNrNoEAABgZZvlSOvLkzw7yeYdLVBVp1bVuVV17tz1N86uMgAAAEZpJqG1qn4kyZXdfd7Oluvu07t7fXevX33IQbMoDQAAgBGb1UjrDyT5saq6JMmZSR5ZVW+YUdsAAACsUDMJrd39u919THcfl+SUJB/u7p+ZRdsAAACsXO7TCgAAwGitmXWD3f1PSf5p1u0CAACw8hhpBQAAYLSEVgAAAEZLaAUAAGC0hFYAAABGS2gFAABgtIRWAAAARktoBQAAYLSEVgAAAEZLaAUAAGC0hFYAAABGS2gFAABgtNYMXcBOreqhK8jd3zp8DVtsfvB9hi4hSbL6sxcNXUKSpDdtGrqErfrmm4cuIUmy+YYbhi4hSVJXXz10CVv13NzQJUysXj10BRPr1g1dwVarbn/40CUkSeYO2m/oEibWrR26gq1uOP7QoUtIkqy7ZhzH+TpoPP83a6+8fugSkiRzdxxHH6mR1JEkmw4dybGkhi5gm7Wf+OLQJSRJNt9yy9AlsJuMtAIAADBaQisAAACjJbQCAAAwWkIrAAAAoyW0AgAAMFpCKwAAAKMltAIAADBaQisAAACjJbQCAAAwWkIrAAAAoyW0AgAAMFpCKwAAAKMltAIAADBaa2bVUFVdkuT6JHNJNnX3+lm1DQAAwMo0s9A69UPd/a0ZtwkAAMAK5fRgAAAARmuWobWTvL+qzquqU2fYLgAAACvULE8Pflh3X15Vd0rygaq6oLvPmr/ANMyemiSr73DYDEsDAABgjGY20trdl0//vTLJ25OctJ1lTu/u9d29fvXBB82qNAAAAEZqJqG1qg6qqkO2PE7ymCRfmEXbAAAArFyzOj34yCRvr6otbb6pu/9xRm0DAACwQs0ktHb3xUkeMIu2AAAA2Hu45Q0AAACjJbQCAAAwWkIrAAAAoyW0AgAAMFpCKwAAAKMltAIAADBaQisAAACjJbQCAAAwWkIrAAAAoyW0AgAAMFpCKwAAAKMltAIAADBa1d1D17BdVfXNJF/Zzc3cMcm3lqEc2NP0VVYS/ZWVRH9lJdFfWSmWo6/erbuPWMyCow2ty6Gqzu3u9UPXAbuir7KS6K+sJPorK4n+ykox677q9GAAAABGS2gFAABgtPb20Hr60AXAIumrrCT6KyuJ/spKor+yUsy0r+7V17QCwJhU1fOS3KO7f3HoWgBgpRBaAWCZVNUN854emOTWJHPT50/v7jfOvioAWNmEVgDYA6rqkiS/2N0fHLoWAFjJ9vZrWgFgNKrqtKp6w/TxcVXVVfW0qrq0qq6uql+uqodU1eeq6pqq+vMF6/9CVZ0/XfZ9VXW3YT4JAMyO0AoAw/reJMcneXKSlyf5vSSPSnK/JE+qqocnSVU9Icnzkvx4kiOSfDTJm4coGABmSWgFgGH9YXff0t3vT3Jjkjd395XdfXkmwfTE6XK/nOTF3X1+d29K8qIkDzTaCsDeTmgFgGF9Y97jm7fz/ODp47slecX0tOFrklyVpJIcPZMqAWAga4YuAABY
lEuTvNAMxADsa4y0AsDK8Ookv1tV90uSqjq0qn5y4JoAYI8z0goAK0B3v72qDk5y5vQ61muTfCDJW4etDAD2LPdpBQAAYLScHgwAAMBoCa0AAACMltAKAADAaAmtAAAAjJbQCgAAwGgJrQAAAIyW0AoAAMBoCa0AAACM1v8HryZDvFY2lKEAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 1152x265.846 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA60AAAERCAYAAACQMkH4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAF1lJREFUeJzt3Xm0ZWV5J+Dfa0FARTQoGhkUk4DiWGoFTUtrHFrQNmLSLcKyo5ks7ZhoVuw2iaYbNHHodBzTTkRZZFCJ0ahE7S7ROKGRAII4EGkaMcg8iDgylG//cU7B5VoF51TdOmffc59nrbvuPnv63nvPV4f749v729XdAQAAgCG63bwLAAAAgG0RWgEAABgsoRUAAIDBEloBAAAYLKEVAACAwRJaAQAAGCyhFQAGoqpeXVWn7+A5dq+qrqqnrFRdADBPQisAC2Uc2G7t64Q51nZkVX2yqr5dVd+tqrOr6mVVdbeVaqO7f5jknklOXqlzAsA8Ca0ALJp7Lvl6zlbWvXBrB1XVrjuzqKp6TZJ3JfnnJE9K8oAkv5fk4CS/sZJtdfel3X3dSp4TAOZFaAVgoYwD26XdfWmSa5av6+5vV9X9xqOuT6+qT1XVD5M8u6qeV1VXLj1fVR0+3nePJeseXVWnVNUPqurCqvrzpduXq6pHZxRQX9DdL+7uz3X3N7r7Y919ZJK3Ldv/WVX19aq6tqreW1U/uWTbuqp6eVV9s6quq6qzqurJS7b/2OXBVbV/VZ1YVVdV1fer6oyqOnTJ9l+uqjOr6odVdX5VHbs0xFfVM6rqy+Of9+qq+kRV3XWqNwYAtpPQCsBa9uokr8totPMjkxxQVQ9P8r+TvCfJg5IcmeSRSd56K4c9M8m3siycbtHd1yx5ed8kvzj+enKSn09y7JLtL07ygiQvSvLgJJuSfLCqDt5GvXsm+UySn0ry1PExr0pS4+1PTXJ8Rr+H+yfZmORXkhwz3n7vJO8c/3wHJ3lMkhNv5WcFgBW1y7wLAIA5em13f2DLi6qa5JjfT3JCd79x/Pq8qvqdJP9UVb/V3ddu5ZgDk/zf7t48YV2/2t3fG9d0fJJfWrLtvyR5ZXf/7ZZ6quoxGYXY39zKuZ6d5C5JnrYkHJ+3ZPsfJXlFd//V+PX5VfXSJG8eb9s3o//J/d7x6HWSfGnCnwMAdpjQCsBatj0z9T48yX5V9ewl67ak3Z9JcuZWjpkoDY+dvyWwjl2c5O5JUlV3T7JXks8uO+aUJP9mG+d7aJIzlo3mZny+Gm9/UFUds2TT7ZLcfnxZ8mkZjdR+rao+mtEET+/r7qum+JkAYLsJrQCsZd9b9vpH+fGAuXyCptsleVNGI5HLXbiNds5N8vSqWjfBaOsNy153JrudpyfYZ7kan/uPknxwK9uv7e7NVfXYjC5TfmKS/5zk1VX1qO4+ZzvaBICpuKcVAG52RZK7VNXuS9atX7bPF5I8oLvP28rXtmbsfVeSn0zy3K1trKq7TFJcd1+e5Kokj1q26dAkX93GYWcmedjW2ujuHyU5K8lB2/h5Nm/Zr7s/293HZDTS/K0kT5+kZgDYUUZaAeBmn0tyfZJXVdWbMgpoz1m2zyuTfK6q3pjRBEbfy2iCosO6+/lbO2l3f2q8/xvGExt9IKPLfn8mo4mPzkzyPyas8c+SvKSqvp7ki0l+fVznr21j/79K8l+TfGB8r+olSR6S5Mru/kySlyV5X1VdlOR9GY02PyjJ+u5+SVX924xC8clJLk/ycxk9OmhbIRkAVpSRVgAY6+7Lkjwro5l7vzRe/u/L9jkjoxl0D87oXtIzk/xJkktzK7r7hePz/XxGM/5+NckbknwtyV9MUeb/TPLGJK9P8uWMnvn6tG1dqtvd307y6CRXZjRD8peSvDSjcJruPinJEUkOz+ge389nNNnTN8anuCbJL4yPPTejmYdf2t3vnaJmANhu1b09t8AMX1UdntEf
A+uSvL27Xz3nkmC7VNUFSb6TZHOSG7t7w3wrgts2nvH2KUku7+4HjtftleRvkxyQ5IIkR3b3t+ZVI0xiG3352IxG4K8Y7/aS7p7okUkwD1W1f0ZXXdwjo/vfj+vuN/hcZrVYyJHWqlqX0SQZT8romXNHV9X951sV7JDHdvd6gZVV5ISMRu6W+oMkH+/uA5N8fPwahu6E/HhfTpLXjT+X1wusrAI3JnlRd98/o+dKP3/8t7HPZVaFhQytSQ5Jcl53n9/d12f0EPQj5lwTwJrR3Z9OcvWy1Uck+cvx8l8medpMi4LtsI2+DKtKd1/S3V8YL38nyTkZPYPZ5zKrwqKG1n1zy8cOfHO8DlajTvLRqjqjqjbOuxjYAffo7kvGy5dmdJkarFa/XVVnV9Xx4+fZwqpQVQdk9HzmU+NzmVViUUMrLJJDu/thGV3u/vyqevS8C4Id1aMJFRZzUgXWgrdkNPPz+oxmY37NfMuByVTVHhnNEv673X3t0m0+lxmyRQ2tFyXZf8nr/cbrYNXp7ovG3y9P8v6MLn+H1eiyqrpnkoy/Xz7nemC7dPdl3b15/Jzbv4jPZVaBqto1o8D6zu7++/Fqn8usCosaWk9LcmBV3aeqfiLJUUlOmnNNMLWqumNV3WnLcpInZvSIC1iNTkry7PHys5N8cI61wHbb8kf+2C/F5zIDV1WV5B1Jzunu1y7Z5HOZVWGRH3nz5IyeYbcuyfHd/Yo5lwRTq6qfzmh0NUl2SfIufZnVoKrendGzPe+W5LIkxyT5QJL3JLlXRs8APbK7TXDDoG2jL/9CRpcGd0aPCXnukvsCYXCq6tAkn8noOc0/Gq9+SUb3tfpcZvAWNrQCAACw+i3q5cEAAAAsAKEVAACAwRJaAQAAGCyhFQAAgMESWgEAABishQ6tVbVx3jXAStCXWRT6MotAP2ZR6MusFgsdWpP4h8ii0JdZFPoyi0A/ZlHoy6wKix5aAQAAWMWqu+ddw1b9RO3Wu+eOO3SOG3Jdds1uK1QRzM+i9uWDHvz9eZdwC+eefYd5lzBYK/VeXXHV5ux913U7fB7vFdtrJfqyfjwb/hux8y3q3xds25D+XV1w4Q258urNNcm+gw2te9Ze/Yh6/LzLAHaiTRefNe8SbuGwfdbPu4TB8l6xKIbUl/XjWzek9yrxfrEYhvTv6pDDLszpX/zhRKHV5cEAAAAMltAKAADAYAmtAAAADJbQCgAAwGAJrQAAAAyW0AoAAMBgCa0AAAAMltAKAADAYAmtAAAADJbQCgAAwGAJrQAAAAyW0AoAAMBgzSy0VtXhVfW1qjqvqv5gVu0CAACwes0ktFbVuiRvSvKkJPdPcnRV3X8WbQMAALB6zWqk9ZAk53X3+d19fZITkxwxo7YBAABYpWYVWvdNcuGS198cr7uFqtpYVadX1ek35LoZlQYAAMBQDWoipu4+rrs3dPeGXbPbvMsBAABgzmYVWi9Ksv+S1/uN1wEAAMA2zSq0npbkwKq6T1X9RJKjkpw0o7YBAABYpXaZRSPdfWNV/XaSTUnWJTm+u78yi7YBAABYvWYSWpOkuz+S5COzag8AAIDVb1ATMQEAAMBSQisAAACDJbQCAAAwWEIrAAAAgyW0AgAAMFhCKwAAAIMltAIAADBYQisAAACDJbQCAAAwWEIrAAAAgyW0AgAAMFhCKwAAAINV3T3vGrZqz9qrH1GPn3cZAAzQpovPmncJt3DYPuvnXcJN/G5WD+8VsJad2h/PtX11TbKvkVYAAAAGS2gFAABgsIRWAAAABktoBQAAYLCEVgAAAAZLaAUAAGCwhFYAAAAGS2gFAABgsIRWAAAABktoBQAAYLCEVgAAAAZLaAUAAGCwhFYAAAAGayahtaqOr6rLq+rLs2gPAACAxTCrkdYTkhw+o7YAAABYEDMJrd396SRXz6ItAAAAFod7WgEAABisXeZdwFJVtTHJxiTZPXeYczUAAADM26BGWrv7uO7e0N0bds1u8y4HAACA
ORtUaAUAAIClZvXIm3cn+ack962qb1bVb8yiXQAAAFa3mdzT2t1Hz6IdAAAAFovLgwEAABgsoRUAAIDBEloBAAAYLKEVAACAwRJaAQAAGCyhFQAAgMESWgEAABgsoRUAAIDBEloBAAAYLKEVAACAwRJaAQAAGCyhFQAAgMESWgEAABisXeZdwLYc9ODvZ9Oms+Zdxk0O22f9vEu4yaaLh/N7SYb1uxmiIb1f3qtbN6T3amiG1neGVg9sD/0YYDJGWgEAABgsoRUAAIDBEloBAAAYLKEVAACAwRJaAQAAGCyhFQAAgMESWgEAABgsoRUAAIDBEloBAAAYLKEVAACAwRJaAQAAGCyhFQAAgMGaOLRW1dFVdfB4+b5V9emq+kRV3W/nlQcAAMBaNs1I658kuXq8/GdJ/jnJp5K8+bYOrKr9xwH3q1X1lap64fSlAgAAsNbsMsW+e3f3ZVW1e5JDk/zHJDckuXKCY29M8qLu/kJV3SnJGVV1cnd/dfqSAQAAWCumCa1XVNXPJnlQktO6+7qqukOSuq0Du/uSJJeMl79TVeck2TeJ0AoAAMA2TRNa/zjJGUk2J3nGeN0Tknxxmgar6oAkD01y6jTHAQAAsPZMHFq7+4Sqes94+fvj1Z9PctSk56iqPZK8L8nvdve1W9m+McnGJLnXvtPkaQAAABbRtI+8uX2S/1BVLx6/3iUTBt+q2jWjwPrO7v77re3T3cd194bu3rD3XddNWRoAAACLZppH3jwmydeSPDPJfxuvPjDJWyY4tpK8I8k53f3a7agTAACANWiakdbXJ3lGdx+e0WzAyei+1EMmOPZRSX4lyeOq6qzx15OnKxUAAIC1ZpobRw/o7o+Pl3v8/fpJztHdp2SCWYYBAABgqWlGWr9aVYctW/eEJF9awXoAAADgJtOMtL4oyYeq6sNJbl9Vb0vyi0mO2CmVAQAAsOZNPNLa3Z9P8pAkX0lyfJKvJzmku0/bSbUBAACwxk31MNTuvijJn+6kWgAAAOAWbjW0VtVf5+ZJl7apu5+1YhUBAADA2G2NtJ43kyoAAABgK241tHb3y2ZVCAAAACw31T2tVfW4JEcn2SfJxUlOXPLsVgAAAFhRE88eXFUvSnJikquTfDjJVUneNV4PAAAAK26akdbfS/K47v7ylhXjiZpOTvKalS4MAAAAJh5pHVs+MdP5mWB2YQAAANge04TWY5O8o6oOrKrbV9VBSY5LckxV3W7L106pEgAAgDVpmsuD3zb+fnRGo6s1fv3M8bYar1+3YtWxKmy6+Kx5lwAr4rB91s+7BAAAlpkmtN5np1UBAAAAWzFxaO3ub+zMQgAAAGC5iUNrVd05yQuSPDTJHku3dfcTV7guAAAAmOry4L/L6H7V9yf5wc4pBwAAAG42TWh9ZJK7dff1O6sYAAAAWGqaR9SckuR+O6sQAAAAWG6akdZfTfKRqjo1yWVLN3T3y1eyKAAAAEimC62vSLJ/kguS7Llkfa9kQQAAALDFNKH1qCQHdfclO6sYAAAAWGqae1rPT3LDzioEAAAAlptmpPWvk5xUVX+eH7+n9R9XtCoAAADIdKH1+ePvr1y2vpP89MqUAwAAADebOLR29312ZiEAAACw3DT3tAIAAMBMTTzSWlV7Jjk2yWOS3C1JbdnW3fe6jWN3T/LpJLuN23xvdx+zHfUCAACwhkwz0vrmJA9L8vIkeyX5nST/muR1Exx7XZLHdfdDkqxPcnhVPXLKWgEAAFhjppmI6YlJDu7uq6pqc3d/sKpOT/IPuY3g2t2d5Lvjl7uOv3p7CgYAAGDtmGak9XZJvj1e/m5V3TnJJUl+dpKDq2pdVZ2V5PIkJ3f3qVNVCgAAwJozTWj9Ykb3sybJKRldLvyWJOdOcnB3b+7u9Un2S3JIVT1w+T5VtbGqTq+q06+4avMUpQEAALCIpgmtz0lywXj5BUl+kOTOSZ41TYPdfU2STyQ5fCvbjuvuDd29Ye+7rpvmtAAAACyg2wyt
VfXwqnpgd5/f3f+vqvbO6B7WQzK6XPhfJzjH3lV1l/Hy7ZP8uyT/smOlAwAAsOgmGWl9fZKfWvL67UkOSvK2JA9I8qcTnOOeST5RVWcnOS2je1o/NGWtAAAArDGTzB58cJLPJMl4tPRJSR7Y3edW1UlJPpfkt27tBN19dpKH7mCtAAAArDGTjLTukuT68fIjk1za3ecmSXdfmOQuO6k2AAAA1rhJQutXkjx9vHxUko9t2VBV++bmx+AAAADAiprk8uDfT/IPVfXWJJuTHLpk2zOSfHZnFAYAAAC3GVq7+5SquldGky+d293fWbL5w0lO3FnFAQAAsLZNMtKacVA9Yyvrv7biFQEAAMDYJPe0AgAAwFwIrQAAAAyW0AoAAMBgCa0AAAAMltAKAADAYAmtAAAADJbQCgAAwGBVd8+7hq3as/bqR9Tj513GTTZdfNa8S7jJYfusn3cJAAAA2+3U/niu7atrkn2NtAIAADBYQisAAACDJbQCAAAwWEIrAAAAgyW0AgAAMFhCKwAAAIMltAIAADBYQisAAACDJbQCAAAwWEIrAAAAgyW0AgAAMFhCKwAAAIMltAIAADBYQisAAACDNdPQWlXrqurMqvrQLNsFAABgdZr1SOsLk5wz4zYBAABYpWYWWqtqvyT/PsnbZ9UmAAAAq9ssR1pfn+TFSX60rR2qamNVnV5Vp9+Q62ZXGQAAAIM0k9BaVU9Jcnl3n3Fr+3X3cd29obs37JrdZlEaAAAAAzarkdZHJXlqVV2Q5MQkj6uqv5lR2wAAAKxSMwmt3f2H3b1fdx+Q5Kgk/9jd/2kWbQMAALB6eU4rAAAAg7XLrBvs7k8m+eSs2wUAAGD1MdIKAADAYAmtAAAADJbQCgAAwGAJrQAAAAyW0AoAAMBgCa0AAAAMltAKAADAYAmtAAAADJbQCgAAwGAJrQAAAAyW0AoAAMBgCa0AAAAM1i7zLmBbDnrw97Np01nzLuMmh+2zft4l3GTTxcP5vSTD+t0AADBb/jZlZzPSCgAAwGAJrQAAAAyW0AoAAMBgCa0AAAAMltAKAADAYAmtAAAADJbQCgAAwGAJrQAAAAyW0AoAAMBgCa0AAAAMltAKAADAYAmtAAAADJbQCgAAwGDtMquGquqCJN9JsjnJjd29YVZtAwAAsDrNLLSOPba7r5xxmwAAAKxSLg8GAABgsGYZWjvJR6vqjKraOMN2AQAAWKVmeXnwod19UVXdPcnJVfUv3f3ppTuMw+zGJLnXvrO+chkAAIChmdlIa3dfNP5+eZL3JzlkK/sc190bunvD3nddN6vSAAAAGKiZhNaqumNV3WnLcpInJvnyLNoGAABg9ZrVNbj3SPL+qtrS5ru6+//MqG0AAABWqZmE1u4+P8lDZtEWAAAAi8MjbwAAABgsoRUAAIDBEloBAAAYLKEVAACAwRJaAQAAGCyhFQAAgMESWgEAABgsoRUAAIDBEloBAAAYLKEVAACAwRJaAQAAGCyhFQAAgMESWgEAABis6u5517BVVXVFkm/s4GnuluTKFSgH5k1fZlHoyywC/ZhFoS8zT/fu7r0n2XGwoXUlVNXp3b1h3nXAjtKXWRT6MotAP2ZR6MusFi4PBgAAYLCEVgAAAAZr0UPrcfMuAFaIvsyiWNN9uapeUlVvn3cd7LA13Y9ZKPoyq8JC39MKALNUVd9d8vIOSa5Lsnn8+rnd/c7ZVwUAq5vQCgA7QVVdkOQ3u/tj864FAFazRb88GAAGo6qOraq/GS8fUFVdVb9WVRdW1beq6nlV9XNVdXZVXVNV/2vZ8b9eVeeM991UVfeez08CALMjtALAfD0iyYFJnpHk9UlemuQJSR6Q5MiqekySVNURSV6S5JeT7J3kM0nePY+CAWCWhFYAmK8/7u4fdvdHk3wvybu7+/LuviijYPrQ8X7PS/Kq7j6nu29M8sok6422ArDohFYAmK/Lliz/YCuv9xgv3zvJG8aXDV+T5OoklWTfmVQJAHOyy7wL
AAAmcmGSV5iBGIC1xkgrAKwOb03yh1X1gCSpqjtX1dPnXBMA7HRGWgFgFeju91fVHklOHN/H+u0kJyf5u/lWBgA7l+e0AgAAMFguDwYAAGCwhFYAAAAGS2gFAABgsIRWAAAABktoBQAAYLCEVgAAAAZLaAUAAGCwhFYAAAAGS2gFAABgsP4/WNqU855zUE0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 1152x276.48 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# alphas (predicted choice) vs. true choices on 6 random samples.\n",
    "# plt.matshow() opens a new figure on its own, so a preceding plt.figure()\n",
    "# would only leave an empty figure in the cell output.\n",
    "rand_indexes = np.random.choice(n_data, 6)\n",
    "source_batch = source_data[rand_indexes]\n",
    "target_batch = target_data[rand_indexes]\n",
    "emotions = category_data[rand_indexes]\n",
    "\n",
    "choice_preds = sess.run(alphas,\n",
    "                        feed_dict={\n",
    "                            source_ids: source_batch,\n",
    "                            target_ids: target_batch,\n",
    "                            emo_cat: emotions})\n",
    "choice_batch = choice_data[rand_indexes]\n",
    "\n",
    "\n",
    "def plot_choices(matrix, title, path):\n",
    "    \"\"\"Show `matrix` with plt.matshow (x: Time, y: Samples) and save it to `path`.\"\"\"\n",
    "    plt.matshow(matrix)\n",
    "    plt.title(title, fontsize=14)\n",
    "    plt.xlabel(\"Time\", fontsize=12)\n",
    "    plt.ylabel(\"Samples\", fontsize=12)\n",
    "    plt.savefig(path)\n",
    "\n",
    "\n",
    "plot_choices(choice_preds, \"Predicted Choices\", \"./predict_choice\")\n",
    "plot_choices(choice_batch, \"True Choices\", \"./true_choice\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading inference data ...\n",
      "\tDone.\n"
     ]
    }
   ],
   "source": [
    "print(\"Loading inference data ...\")\n",
    "\n",
    "# id_0, id_1, id_2 preserved for SOS, EOS, constant zero padding\n",
    "embed_shift = 3\n",
    "filename = config[\"inference\"][\"infer_source_file\"]\n",
    "c_filename = config[\"inference\"][\"infer_category_file\"]\n",
    "max_leng = config[\"inference\"][\"infer_source_max_length\"]\n",
    "\n",
    "# NOTE: this overwrites source_data / category_data used by earlier cells.\n",
    "source_data = loadfile(filename, is_source=True,\n",
    "                       max_length=max_leng) + embed_shift\n",
    "category_data = pd.read_csv(\n",
    "    c_filename, header=None, index_col=None, dtype=int)[0].values\n",
    "\n",
    "# One emotion category per source sequence; a size mismatch would mis-pair\n",
    "# sequences and categories when the inference cell slices batches.\n",
    "assert len(source_data) == len(category_data), (\n",
    "    \"source/category size mismatch: %d vs %d\"\n",
    "    % (len(source_data), len(category_data)))\n",
    "\n",
    "print(\"\\tDone.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start inferring ...\n",
      "\tDone.\n"
     ]
    }
   ],
   "source": [
    "# Inference: decode the inference set batch by batch and write the predicted\n",
    "# token ids and choices to the CSV paths given in config[\"inference\"].\n",
    "print(\"Start inferring ...\")\n",
    "n_data = source_data.shape[0]\n",
    "\n",
    "# Pad to a whole number of batches. The padded arrays get their own names so\n",
    "# re-running this cell cannot pad source_data twice (n_data would then count\n",
    "# padding rows and their predictions would leak into the output files).\n",
    "n_pad = n_data % infer_batch_size\n",
    "if n_pad > 0:\n",
    "    n_pad = infer_batch_size - n_pad\n",
    "\n",
    "pad = np.zeros((n_pad, source_data.shape[1]), dtype=source_data.dtype)\n",
    "source_padded = np.concatenate((source_data, pad))\n",
    "# keep the categories integer-typed (np.zeros defaults to float64)\n",
    "category_padded = np.concatenate(\n",
    "    (category_data, np.zeros(n_pad, dtype=category_data.dtype)))\n",
    "\n",
    "final_result = []\n",
    "for start in range(0, len(source_padded), infer_batch_size):\n",
    "    end = start + infer_batch_size\n",
    "    result = sess.run(infer_outputs,\n",
    "                      feed_dict={source_ids: source_padded[start:end],\n",
    "                                 emo_cat: category_padded[start:end]})\n",
    "    # keep the first beam only (ids presumably (batch, time, beam) -- confirm)\n",
    "    result = result.ids[:, :, 0]\n",
    "\n",
    "    # right-pad short decodes to max_iter steps with id 1 (EOS)\n",
    "    if result.shape[1] < max_iter:\n",
    "        l_pad = max_iter - result.shape[1]\n",
    "        eos_pad = np.ones((infer_batch_size, l_pad), dtype=result.dtype)\n",
    "        result = np.concatenate((result, eos_pad), axis=1)\n",
    "\n",
    "    final_result.append(result)\n",
    "\n",
    "final_result = np.concatenate(final_result)[:n_data] - embed_shift\n",
    "# ids >= vocab_size are flagged as choice 1 and shifted down by\n",
    "# vocab_size + embed_shift (presumably into the emotion-memory vocabulary).\n",
    "# np.int was removed in NumPy 1.24; the builtin int is the same type.\n",
    "choice_pred = (final_result >= vocab_size).astype(int)\n",
    "final_result[final_result >= vocab_size] -= (vocab_size + embed_shift)\n",
    "\n",
    "# transform to output format: ids that are negative after the shifts\n",
    "# (e.g. the reserved SOS/EOS/padding ids) are written as -1\n",
    "final_result[final_result < 0] = -1\n",
    "final_result = [\" \".join(row)\n",
    "                for row in final_result.astype(int).astype(str).tolist()]\n",
    "choice_pred = [\" \".join(row) for row in choice_pred.astype(str).tolist()]\n",
    "\n",
    "df = pd.DataFrame(data={\"0\": final_result})\n",
    "df.to_csv(config[\"inference\"][\"output_path\"], header=None, index=None)\n",
    "\n",
    "cdf = pd.DataFrame(data={\"0\": choice_pred})\n",
    "cdf.to_csv(config[\"inference\"][\"choice_path\"], header=None, index=None)\n",
    "print(\"\\tDone.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "24.403917003458464"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# perplexity on a random dev subset, drawn without replacement so no\n",
    "# example is counted twice (capped at the dev-set size)\n",
    "n_eval = min(256, len(dev_source_data))\n",
    "random_indexes = np.random.choice(len(dev_source_data), n_eval, replace=False)\n",
    "s = dev_source_data[random_indexes]\n",
    "t = dev_target_data[random_indexes]\n",
    "q = dev_choice_data[random_indexes]\n",
    "c = dev_category_data[random_indexes]\n",
    "\n",
    "# mask of target positions that are not -1 (presumably padding)\n",
    "m = (t != -1)\n",
    "\n",
    "feed_dict = {\n",
    "    source_ids: s,\n",
    "    target_ids: t,\n",
    "    choice_qs: q,\n",
    "    emo_cat: c,\n",
    "    sequence_mask: m,\n",
    "}\n",
    "\n",
    "compute_perplexity(sess, CE, m, feed_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
